utils.py 35.7 KB
Newer Older
1
import argparse
2
import asyncio
3
import contextlib
4
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
5
import enum
6
import gc
7
import os
8
import socket
9
import subprocess
10
import sys
11
12
import tempfile
import threading
Zhuohan Li's avatar
Zhuohan Li committed
13
import uuid
14
import warnings
15
from asyncio import FIRST_COMPLETED, ensure_future
16
from collections import defaultdict
17
from functools import lru_cache, partial, wraps
18
from platform import uname
19
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
20
21
                    Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
                    Type, TypeVar, Union, overload)
22
from uuid import uuid4
Zhuohan Li's avatar
Zhuohan Li committed
23

24
import numpy as np
25
import numpy.typing as npt
26
import psutil
Zhuohan Li's avatar
Zhuohan Li committed
27
import torch
28
import torch.types
29
from typing_extensions import ParamSpec, TypeIs, assert_never
30

31
import vllm.envs as envs
32
from vllm import _custom_ops as ops
33
from vllm.logger import enable_trace_function_call, init_logger
34
35
36

# Module-level logger, obtained from vllm.logger.init_logger for this module.
logger = init_logger(__name__)

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Exception strings for non-implemented encoder/decoder scenarios

STR_NOT_IMPL_ENC_DEC_SWA = ("Sliding window attention for encoder/decoder "
                            "models is not currently supported.")

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = ("Prefix caching for encoder/decoder "
                                     "models is not currently supported.")

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = ("Chunked prefill for encoder/decoder "
                                        "models is not currently supported.")

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap "
    "require FlashInfer backend, which is "
    "currently not supported for encoder/decoder "
    "models.")

# Previously read "LoRA is currently not currently supported ..."; the
# duplicated word has been removed.
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")

STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                           "currently supported with "
                           "encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                           "supported with encoder/decoder "
                           "models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
                                  "currently supported with encoder/"
                                  "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
                                "currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above.
# NOTE: the CUDA-graph entry is keyed "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH"
# (with an underscore), unlike the variable STR_NOT_IMPL_ENC_DEC_CUDAGRAPH.
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}

# Constants related to forcing the attention backend selection

# Name of the environment variable which may be set in order to
# force the attention backend chosen by the Attention wrapper.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible values of the STR_BACKEND_ENV_VAR environment variable,
# one per attention backend.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

118
119
120
GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""

121
122
123
124
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
125
    "fp8": torch.uint8,
126
127
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
128
}
Zhuohan Li's avatar
Zhuohan Li committed
129

130
131
132
133
134
135
136
137
138
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}

139
140
141
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
142
U = TypeVar("U")
143

Woosuk Kwon's avatar
Woosuk Kwon committed
144

145
146
147
148
149
150
151
class _Sentinel:
    """Private marker type; its instances are compared by identity."""


# Sentinel used by LRUCache.remove_oldest to detect that every entry is pinned.
ALL_PINNED_SENTINEL = _Sentinel()


Woosuk Kwon's avatar
Woosuk Kwon committed
152
153
154
155
156
157
158
159
160
161
class Device(enum.Enum):
    """Kind of device a resource lives on."""

    GPU = enum.auto()
    CPU = enum.auto()


class Counter:
    """Integer counter driven by the built-in ``next()``.

    ``next(counter)`` returns the current value and then increments it.
    ``reset()`` sets the value back to 0 (not to ``start``).
    """

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        current, self.counter = self.counter, self.counter + 1
        return current

    def reset(self) -> None:
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
169

170

171
class LRUCache(Generic[T]):
    """Least-recently-used mapping with optional pinning of entries.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Whenever the size exceeds ``capacity``, the oldest *unpinned*
    entries are evicted; a ``RuntimeError`` is raised if every entry is
    pinned. Subclasses may override ``_on_remove`` to react when an entry
    leaves the cache.
    """

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> T:
        # Index first so a missing key raises KeyError before any reordering.
        found = self.cache[key]
        self.cache.move_to_end(key)
        return found

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        """Mark ``key`` as most recently used (KeyError if absent)."""
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Return the value for ``key`` and refresh its recency.

        Returns ``default_value`` without touching the cache when absent.
        """
        if key not in self.cache:
            return default_value
        found = self.cache[key]
        self.cache.move_to_end(key)
        return found

    def put(self, key: Hashable, value: T) -> None:
        """Insert or overwrite ``key`` as most recent, then evict if needed."""
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self.cache:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: Hashable) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: Hashable, value: Optional[T]):
        """Hook called after ``key`` is removed; no-op by default."""
        pass

    def remove_oldest(self, remove_pinned=False):
        """Evict the least recently used entry.

        Pinned entries are skipped unless ``remove_pinned`` is True. Does
        nothing on an empty cache; raises ``RuntimeError`` if every entry is
        pinned and ``remove_pinned`` is False.
        """
        if not self.cache:
            return

        if remove_pinned:
            lru_key = next(iter(self.cache))
        else:
            # Oldest key that is not pinned, or the sentinel if none is.
            unpinned = (k for k in self.cache if k not in self.pinned_items)
            lru_key = next(unpinned, ALL_PINNED_SENTINEL)
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        self.pop(lru_key)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Remove ``key`` and return its value (``default_value`` if absent).

        Also unpins the key, and calls ``_on_remove`` only if the key was
        actually present.
        """
        was_present = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        if key in self.pinned_items:
            self._unpin(key)
        if was_present:
            self._on_remove(key, value)
        return value

    def clear(self):
        """Evict every entry, pinned or not, calling ``_on_remove`` for each."""
        while self.cache:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()


267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
class PyObjectCache:
    """Pool of pre-built Python objects reused across scheduler iterations.

    The pool starts with 128 objects produced by ``obj_builder``.
    ``get_object`` hands them out in order and doubles the pool when it runs
    out. ``reset`` makes the whole pool available again. Objects are handed
    back as-is; they are not reinitialised on reuse.
    """

    def __init__(self, obj_builder):
        self._obj_builder = obj_builder
        self._index = 0
        self._obj_cache = [obj_builder() for _ in range(128)]

    def _grow_cache(self):
        # Double the pool: build as many new objects as it already holds.
        self._obj_cache.extend(
            [self._obj_builder() for _ in range(len(self._obj_cache))])

    def get_object(self):
        """Return the next unused pooled object.

        If every object has been handed out, the pool size doubles first.
        """
        slot = self._index
        if slot >= len(self._obj_cache):
            self._grow_cache()
            assert slot < len(self._obj_cache)
        self._index = slot + 1
        return self._obj_cache[slot]

    def reset(self):
        """Make every pooled object available for the next iteration."""
        self._index = 0


305
306
307
308
def is_hip() -> bool:
    """Return True when the installed torch was built with HIP (ROCm)."""
    hip_version = torch.version.hip
    return hip_version is not None


309
310
@lru_cache(maxsize=None)
def is_cpu() -> bool:
    """Return True if the installed vllm distribution is a CPU build.

    The check looks for "cpu" in the vllm package version string. Returns
    False when vllm is not installed. The result is cached.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "cpu" in vllm_version
316
317


318
319
320
321
322
323
324
325
326
@lru_cache(maxsize=None)
def is_openvino() -> bool:
    """Return True if the installed vllm distribution is an OpenVINO build.

    The check looks for "openvino" in the vllm package version string.
    Returns False when vllm is not installed. The result is cached.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "openvino" in vllm_version


327
@lru_cache(maxsize=None)
def is_neuron() -> bool:
    """Return True if ``transformers_neuronx`` can be imported (cached)."""
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True


336
337
338
339
340
341
342
343
344
@lru_cache(maxsize=None)
def is_tpu() -> bool:
    """Return True if ``libtpu`` can be imported (cached)."""
    try:
        import libtpu  # noqa: F401
    except ImportError:
        return False
    return True


345
346
@lru_cache(maxsize=None)
def is_xpu() -> bool:
    """Return True when vllm is an XPU build and an XPU device is usable.

    All of the following must hold:
      * the vllm package version string contains "xpu";
      * ``intel_extension_for_pytorch`` imports successfully (a failure is
        logged as two warnings);
      * ``torch.xpu`` exists and reports an available device.
    The result is cached.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        built_for_xpu = "xpu" in version("vllm")
    except PackageNotFoundError:
        return False
    # vllm is not built with xpu
    if not built_for_xpu:
        return False

    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
    except ImportError as e:
        logger.warning("Import Error for IPEX: %s", e.msg)
        # ipex dependency is not ready
        logger.warning("not found ipex lib")
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()


368
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    The value comes from the custom op
    ``ops.get_max_shared_memory_per_block_device_attribute`` for device index
    ``gpu`` and is cached per device index.
    """
    limit = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert limit > 0, "max_shared_mem can not be zero"
    return int(limit)


379
def get_cpu_memory() -> int:
    """Return the total physical memory of this node in bytes (via psutil)."""
    vm_stats = psutil.virtual_memory()
    return vm_stats.total
Zhuohan Li's avatar
Zhuohan Li committed
382
383
384
385


def random_uuid() -> str:
    """Return a random UUID4 as a 32-character lowercase hex string."""
    return uuid.uuid4().hex
386

387

388
@lru_cache(maxsize=None)
def get_vllm_instance_id() -> str:
    """Return the identifier shared by all processes of one vLLM instance.

    Uses ``VLLM_INSTANCE_ID`` when it is set and non-empty. Otherwise it
    generates ``"vllm-instance-<random hex>"``. The result is cached, so
    later calls in the same process return the same id.
    """
    configured = envs.VLLM_INSTANCE_ID
    if configured:
        return configured
    return f"vllm-instance-{random_uuid()}"
397
398


399
@lru_cache(maxsize=None)
def in_wsl() -> bool:
    """Return True when ``platform.uname()`` mentions "microsoft" (WSL)."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    system_info = " ".join(uname()).lower()
    return "microsoft" in system_info
403
404


405
def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking function so each call runs in an executor thread.

    Calling the returned wrapper submits ``func(*args, **kwargs)`` to the
    event loop's default executor and returns the resulting future. This
    keeps the blocking call from stalling the asyncio event loop. ``func``
    runs outside the event-loop thread, so it must be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        event_loop = asyncio.get_event_loop()
        bound_call = partial(func, *args, **kwargs)
        return event_loop.run_in_executor(executor=None, func=bound_call)

    return _async_wrapper


421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
async def iterate_with_cancellation(
    iterator: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    """Convert async iterator into one that polls the provided function
    at least once per second to check for client cancellation.

    Args:
        iterator: Async generator whose items are re-yielded in order.
        is_cancelled: Awaitable predicate. It is awaited after every wait,
            and each wait lasts at most one second.

    Yields:
        The items produced by ``iterator``, unchanged.

    Raises:
        asyncio.CancelledError: If ``is_cancelled()`` returns True. Before
            raising, the in-flight ``__anext__`` future is cancelled and
            ``iterator.aclose()`` is awaited; any error from that cleanup is
            suppressed. Other exceptions raised by ``iterator`` propagate.
    """

    # Can use anext() in python >= 3.10
    # Single-element list holding the in-flight __anext__() future; the slot
    # is replaced after each item is consumed.
    awaits = [ensure_future(iterator.__anext__())]
    while True:
        # Returns when the future completes or after 1 second, whichever is
        # first, so cancellation is checked even while no item is ready.
        done, pending = await asyncio.wait(awaits, timeout=1)
        if await is_cancelled():
            with contextlib.suppress(BaseException):
                awaits[0].cancel()
                await iterator.aclose()
            raise asyncio.CancelledError("client cancelled")
        if done:
            try:
                item = await awaits[0]
                # Schedule the next __anext__() before handing the item to
                # the consumer.
                awaits[0] = ensure_future(iterator.__anext__())
                yield item
            except StopAsyncIteration:
                # we are done
                return
446
447


448
449
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge several async iterators into one stream of ``(index, item)``.

    Items are yielded as they become ready. ``index`` is the position of the
    source iterator in ``iterators``. Iterators that finish early drop out,
    and the merged stream ends once all of them are exhausted.

    If ``is_cancelled`` is given, it is awaited after every wait, and each
    wait lasts at most one second. As soon as it returns True,
    ``asyncio.CancelledError("client cancelled")`` is raised.

    However the merge ends, every still-pending ``__anext__`` future is
    cancelled and its iterator's ``aclose()`` awaited. Errors from that
    cleanup are suppressed.
    """
    # In-flight __anext__() future -> (source index, source iterator).
    # (anext() could replace __anext__() on Python >= 3.10.)
    in_flight = {
        ensure_future(source.__anext__()): (index, source)
        for index, source in enumerate(iterators)
    }
    poll_interval = 1 if is_cancelled is not None else None

    try:
        while in_flight:
            ready, _ = await asyncio.wait(in_flight.keys(),
                                          return_when=FIRST_COMPLETED,
                                          timeout=poll_interval)
            if is_cancelled is not None and await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            for future in ready:
                index, source = in_flight.pop(future)
                try:
                    item = await future
                    # Re-arm this source before yielding, so the cleanup
                    # below can cancel it if the consumer stops here.
                    next_future = ensure_future(source.__anext__())
                    in_flight[next_future] = (index, source)
                    yield index, item
                except StopAsyncIteration:
                    # Source exhausted; it was already dropped above.
                    pass
    finally:
        # Cancel any remaining iterators
        for future, (_, source) in in_flight.items():
            with contextlib.suppress(BaseException):
                future.cancel()
                await source.aclose()
490
491


492
def get_ip() -> str:
    """Return an IP address of this host.

    ``VLLM_HOST_IP`` (read through ``vllm.envs``) takes precedence when it is
    non-empty. Otherwise the address of the outbound interface is discovered
    by connecting a UDP socket to a public DNS server, first over IPv4 and
    then over IPv6. If both attempts fail, a warning is issued and
    ``"0.0.0.0"`` is returned.

    Each probe socket is closed after use. An IPv4 failure, including an
    inability to create an IPv4 socket at all, falls through to IPv6.
    """
    host_ip = envs.VLLM_HOST_IP
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    probes = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, target in probes:
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(target)  # Doesn't need to be reachable
                return s.getsockname()[0]
        except Exception:
            # Best effort: try the next address family.
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable"
        " VLLM_HOST_IP or HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"
523
524


525
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` init-method URL.

    Addresses that contain ``:`` (IPv6) are wrapped in square brackets.
    """
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    host = f"[{ip}]" if ":" in ip else ip
    return f"tcp://{host}:{port}"
529
530


531
532
533
534
535
536
537
def get_open_zmq_ipc_path() -> str:
    """Return a fresh ``ipc://`` endpoint under ``VLLM_RPC_BASE_PATH``.

    The final path component is a random UUID4, so repeated calls give
    distinct paths.
    """
    base_dir = envs.VLLM_RPC_BASE_PATH
    endpoint_name = uuid4()
    return f"ipc://{base_dir}/{endpoint_name}"


def get_open_port() -> int:
    """Return a TCP port number that could be bound at the time of the call.

    When ``VLLM_PORT`` is set, ports are probed upward from that value until
    one binds, and each collision is logged at INFO level. Otherwise the OS
    picks an ephemeral port over IPv4, falling back to IPv6 if the IPv4 bind
    fails. The probe socket is closed before returning, so the port is not
    reserved for the caller.
    """
    port = envs.VLLM_PORT
    if port is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(("", port))
                    return port
            except OSError:
                # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d",
                            port, port + 1)
                port += 1

    # Let the OS choose an ephemeral port, trying IPv4 first.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]
    except OSError:
        # IPv4 bind failed; fall back to IPv6.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]
558
559


560
561
def update_environment_variables(envs: Dict[str, str]):
    """Copy every ``name -> value`` pair from ``envs`` into ``os.environ``.

    A warning is logged when an existing variable is changed to a different
    value.
    """
    for name, new_value in envs.items():
        old_value = os.environ.get(name)
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s "
                "from '%s' to '%s'", name, old_value, new_value)
        os.environ[name] = new_value
567
568


569
def chunk_list(lst: List[T], chunk_size: int):
    """Yield consecutive slices of ``lst`` with ``chunk_size`` items each.

    The final slice may be shorter.
    """
    yield from (lst[start:start + chunk_size]
                for start in range(0, len(lst), chunk_size))
573
574
575
576
577
578
579


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer >= a / b."""
    return -(-a // b)


580
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill ``tensor`` in place with fp8 values drawn uniformly in [low, high).

    The values are sampled in a float16 staging tensor and converted with
    ``ops.convert_fp8``, rather than generated as random bit patterns.
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops

    staging = torch.empty_like(tensor, dtype=torch.float16)
    staging.uniform_(low, high)
    ops.convert_fp8(tensor, staging)
    del staging


600
601
602
def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    """Resolve the torch dtype used to store the KV cache.

    Resolution rules:
      * a ``torch.dtype`` is returned unchanged;
      * ``"auto"`` follows ``model_dtype``, which may be a torch dtype or a
        key of ``STR_DTYPE_TO_TORCH_DTYPE``;
      * ``"half"``, ``"bfloat16"`` and ``"float"`` are looked up in
        ``STR_DTYPE_TO_TORCH_DTYPE``;
      * ``"fp8"`` maps to ``torch.uint8``.

    Raises:
        ValueError: For any other ``cache_dtype``, or for ``"auto"`` with a
            ``model_dtype`` that is neither a string nor a torch dtype.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        raise ValueError(f"Invalid model dtype: {model_dtype}")
    if cache_dtype in ("half", "bfloat16", "float"):
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if cache_dtype == "fp8":
        return torch.uint8
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Allocate randomly filled per-layer KV caches in the combined layout.

    Each layer gets a single tensor of shape
    ``(num_blocks, 2, block_size, num_heads, head_size)``. The returned key
    and value caches are views of index 0 and index 1 along dim 1. Values are
    drawn uniformly from ``[-head_size**-0.5, head_size**-0.5]`` after
    seeding torch's RNG (and CUDA's, if available).

    Raises:
        ValueError: If ``get_kv_cache_torch_dtype`` rejects ``cache_dtype``.
            Also raised, when ``num_layers > 0``, if ``cache_dtype`` is not
            one of "auto", "half", "bfloat16", "float" or "fp8".
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    layer_shape = (num_blocks, 2, block_size, num_heads, head_size)
    bound = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        kv = torch.empty(size=layer_shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            kv.uniform_(-bound, bound)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(kv, -bound, bound)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        # Keys and values share storage: split along dim 1.
        key_caches.append(kv[:, 0])
        value_caches.append(kv[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Allocate randomly filled per-layer key and value caches.

    Key caches have shape ``(num_blocks, num_heads, head_size // x,
    block_size, x)``, where ``x = 16 // element_size``; value caches have
    shape ``(num_blocks, num_heads, head_size, block_size)``. Values are
    drawn uniformly from ``[-head_size**-0.5, head_size**-0.5]`` after
    seeding torch's RNG (and CUDA's, if available). All key caches are filled
    before any value cache.

    Raises:
        ValueError: If ``cache_dtype`` is "fp8" and ``head_size`` is not a
            multiple of 16, or if the dtype is unsupported.
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )

    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    bound = head_size**-0.5

    # The innermost key dimension holds 16 bytes' worth of elements.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)

    def _fill_random(cache: torch.Tensor, kind: str) -> None:
        # Fill one cache tensor in place; `kind` only labels the error.
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            cache.uniform_(-bound, bound)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(cache, -bound, bound)
        else:
            raise ValueError(
                f"Does not support {kind} cache of type {cache_dtype}")

    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        _fill_random(key_cache, "key")
        key_caches.append(key_cache)

    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        _fill_random(value_cache, "value")
        value_caches.append(value_cache)
    return key_caches, value_caches
717
718


719
720
721
722
723
724
725
726
727
728
729
730
731
732
@lru_cache(maxsize=128)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` as a warning unless the same message was logged recently.

    Deduplication comes from ``lru_cache`` memoizing the call, so a message
    is suppressed only while it is among the 128 most recently seen.
    """
    logger.warning(msg)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    """Return whether pinned (page-locked) host memory should be used.

    Pinning is disabled under WSL and on XPU, Neuron, CPU and OpenVINO. The
    first three cases log a one-time warning. The result is cached.
    """
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    if is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
    if is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    return not (is_cpu() or is_openvino())


class CudaMemoryProfiler:
    """Context manager that measures device memory allocated within a block.

    On entry and on exit it samples the allocated memory of ``device`` (CUDA
    if available, otherwise XPU). On exit it stores the difference in
    ``consumed_memory`` (bytes), then runs ``gc.collect()``.

    Usage::

        with CudaMemoryProfiler() as m:
            ...
        m.consumed_memory
    """

    def __init__(self, device: Optional[torch.types.Device] = None):
        self.device = device

    def current_memory_usage(self) -> float:
        """Return the device memory currently allocated, in bytes.

        The peak counter is reset immediately before it is read, so the
        reading reflects the current allocation rather than an earlier peak.
        Returns 0 when neither CUDA nor XPU is available; previously this
        raised UnboundLocalError.
        """
        mem = 0
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(self.device)
            mem = torch.cuda.max_memory_allocated(self.device)
        elif is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
        return mem

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()
770
771


772
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Parse a comma-separated string such as ``"1,2,3"`` into ``(1, 2, 3)``.

    Raises:
        ValueError: If any comma-separated piece is not an integer literal.
    """
    pieces = s.split(",")
    try:
        return tuple(int(piece) for piece in pieces)
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


782
783
784
785
786
787
788
789
790
def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Stack ragged rows into a 2-D numpy array, right-padding with ``pad``.

    Row ``i`` of the result holds ``x[i]`` followed by ``pad`` up to
    ``max_len`` columns. ``max_len`` defaults to the longest row (0 when
    ``x`` is empty), and every row must fit within it.
    """
    # Unlike for most functions, map is faster than a genexpr over `len`
    width = max(map(len, x), default=0) if max_len is None else max_len

    out = np.full((len(x), width), pad, dtype=dtype)
    for row_idx, row in enumerate(x):
        assert len(row) <= width
        out[row_idx, :len(row)] = row
    return out


def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a right-padded 2-D tensor from ragged rows.

    Padding is built in numpy (``dtype`` must be a key of
    ``TORCH_DTYPE_TO_NUMPY_DTYPE``). The result is then moved to ``device``
    and optionally pinned.
    """
    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    host_array = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)

    tensor = torch.from_numpy(host_array).to(device)
    return tensor.pin_memory() if pin_memory else tensor
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` and copy it to ``target_device``.

    The copy is issued with ``non_blocking=True``. ``pin_memory`` controls
    whether the host staging tensor is pinned.
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               pin_memory=pin_memory,
                               device="cpu")
    return host_tensor.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Reshape ``tensor`` to ``target_dims`` dimensions if it has fewer.

    The view is ``(-1, size, ..., size)``. Tensors already having at least
    ``target_dims`` dimensions are returned unchanged.
    """
    missing = target_dims - tensor.ndim
    if missing <= 0:
        return tensor
    trailing = (size, ) * missing
    return tensor.view(-1, *trailing)
850
851


852
853
854
855
856
def get_dtype_size(dtype: torch.dtype) -> int:
    """Return the size in bytes of one element of ``dtype``."""
    probe = torch.tensor([], dtype=dtype)
    return probe.element_size()


857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
# `collections` helpers
def is_list_of(
    value: object,
    typ: Type[T],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
    """Return whether ``value`` is a list whose elements are ``typ``.

    With ``check="first"`` only the first element is inspected, and an empty
    list passes. With ``check="all"`` every element is inspected.
    """
    if not isinstance(value, list):
        return False

    if check == "all":
        return all(isinstance(item, typ) for item in value)
    if check == "first":
        return not value or isinstance(value[0], typ)

    assert_never(check)


875
876
def merge_dicts(dict1: Dict[K, List[T]],
                dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
    """Merge 2 dicts that map each key to a list of items.

    When a key appears in both dicts, the two lists are concatenated rather
    than one replacing the other: the items from ``dict1`` come first,
    followed by the items from ``dict2``. Neither input is modified; the
    result is a new plain ``dict`` holding freshly created lists.
    """
    merged_dict: Dict[K, List[T]] = defaultdict(list)
    # dict1 is processed before dict2 so its items lead on shared keys.
    for source in (dict1, dict2):
        for key, value in source.items():
            merged_dict[key].extend(value)

    # Return a plain dict so missing-key lookups on the result raise
    # KeyError instead of silently inserting empty lists.
    return dict(merged_dict)
890
891


892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
# Recursive type alias: a tree whose internal nodes are str-keyed dicts,
# lists, or tuples of subtrees, and whose leaves are values of type T.
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: JSONTree[T],
) -> JSONTree[U]:
    ...


def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
    """Apply ``func`` to every leaf of ``value``, keeping its nesting.

    Dicts, lists and tuples are rebuilt recursively (as plain ``dict``,
    ``list`` and ``tuple`` respectively, even for subclasses); anything else
    is treated as a leaf and replaced by ``func(leaf)``.
    """

    def recurse(node):
        return json_map_leaves(func, node)

    if isinstance(value, dict):
        return {key: recurse(child) for key, child in value.items()}
    if isinstance(value, list):
        return [recurse(child) for child in value]
    if isinstance(value, tuple):
        return tuple(map(recurse, value))
    return func(value)


940
941
942
943
944
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate the inner lists of ``lists`` into one new list, in order."""
    flat: List[T] = []
    for sublist in lists:
        flat.extend(sublist)
    return flat


945
def init_cached_hf_modules() -> None:
    """Initialize the Hugging Face dynamic-module cache on demand.

    ``transformers`` is imported inside the function so that it is only
    loaded when this helper is actually called.
    """
    from transformers.dynamic_module_utils import (
        init_hf_modules as _init_hf_modules)
    _init_hf_modules()
951
952


953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.

    The `ldconfig -p` cache listing is consulted first. If it yields no match
    (or `/sbin/ldconfig` cannot be run), every directory listed in
    `LD_LIBRARY_PATH` is checked instead. The first match is returned.

    Raises:
        ValueError: if the library is not found by either method.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems. It can still be
    # missing (minimal containers, non-Linux hosts); treat that as "no match"
    # so the LD_LIBRARY_PATH fallback below still gets a chance.
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    except (OSError, subprocess.CalledProcessError):
        libs = ""
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = envs.LD_LIBRARY_PATH
    if not locs and env_ld_library_path:
        candidates = (os.path.join(lib_dir, lib_name)
                      for lib_dir in env_ld_library_path.split(":"))
        locs = [path for path in candidates if os.path.exists(path)]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]


981
def find_nccl_library() -> str:
    """
    Return the NCCL/RCCL shared-library name or path to load.

    An explicit `VLLM_NCCL_SO_PATH` environment variable wins. Otherwise the
    library bundled with PyTorch is selected from the build flavour:
    `libnccl.so.2` for CUDA builds, `librccl.so.1` for ROCm (HIP) builds.
    After importing `torch`, those names can be found by `ctypes`
    automatically.

    Raises:
        ValueError: if no path was configured and PyTorch was built for
            neither CUDA nor ROCm.
    """
    configured_path = envs.VLLM_NCCL_SO_PATH
    if configured_path:
        # manually load the nccl library
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
            configured_path)
        return configured_path

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info("Found nccl from library %s", so_file)
    return so_file
1004
1005
1006
1007
1008
1009
1010


def enable_trace_function_call_for_thread() -> None:
    """Turn on function-call tracing for the calling thread.

    This is a no-op unless the VLLM_TRACE_FUNCTION environment variable is
    enabled. When enabled, the trace log goes to
    ``<tempdir>/vllm/<instance id>/`` under a file name built from the
    process id, thread id and current timestamp (spaces replaced by ``_``).
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    pid = os.getpid()
    thread_id = threading.get_ident()
    started_at = datetime.datetime.now()
    filename = (f"VLLM_TRACE_FUNCTION_for_process_{pid}"
                f"_thread_{thread_id}_"
                f"at_{started_at}.log").replace(" ", "_")

    log_dir = os.path.join(tempfile.gettempdir(), "vllm",
                           get_vllm_instance_id())
    os.makedirs(log_dir, exist_ok=True)
    enable_trace_function_call(os.path.join(log_dir, filename))
1020
1021


1022
# `functools` helpers
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
def identity(value: T) -> T:
    """Return ``value`` unchanged; handy as a no-op callable."""
    result = value
    return result


F = TypeVar('F', bound=Callable[..., Any])


def deprecate_kwargs(
        *kws: str,
        is_deprecated: Union[bool, Callable[[], bool]] = True,
        additional_message: Optional[str] = None) -> Callable[[F], F]:
    """Build a decorator that warns when deprecated keyword args are passed.

    Args:
        *kws: Names of the keyword arguments to treat as deprecated.
        is_deprecated: A constant flag, or a zero-argument callable that is
            evaluated on every call so deprecation can be toggled at runtime.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator. The wrapped function emits a ``DeprecationWarning`` when
        any name in ``kws`` is passed *by keyword*, then calls through to the
        original function. Positional arguments are not inspected.
    """
    deprecated_kws = set(kws)

    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # Level 1 is this `inner` frame; level 2 is the code
                        # that called the decorated function, i.e. where the
                        # deprecated keyword was actually passed. `fn` has
                        # not been entered yet, so it adds no frame.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    if not torch.cuda._is_compiled():
        return 0

    if not is_hip():
        raw_count = torch.cuda._device_count_nvml()
    elif hasattr(torch.cuda, "_device_count_amdsmi"):
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = torch.cuda._device_count_amdsmi()
    else:
        raw_count = -1

    # A negative raw count means the stateless query was unavailable; fall
    # back to the regular (initializing) device count.
    if raw_count >= 0:
        return raw_count
    return torch._C._cuda_getDeviceCount()


def cuda_device_count_stateless() -> int:
    """Return the number of CUDA devices without caching a stale answer.

    The result is cached per value of CUDA_VISIBLE_DEVICES as read at call
    time, so changing that variable yields a fresh count. Prefer this over
    torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been
    set to its final value.
    """
    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = envs.CUDA_VISIBLE_DEVICES
    return _cuda_device_count_stateless(visible_devices)
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112


# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):
    """Decorate ``f`` so that only its first invocation actually runs.

    The first call forwards its arguments to ``f`` and returns its result;
    every later call does nothing and returns ``None``. The flag is stored
    on the wrapper as ``has_run`` and re-read on every call.
    """

    # wraps() keeps f's __name__/__doc__/__wrapped__ on the returned wrapper.
    @wraps(f)
    def wrapper(*args, **kwargs) -> Any:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)
        return None

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125


class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that accepts both underscores and dashes in names.

    Before parsing, each long option (``--some_flag`` or
    ``--some_flag=value``) has the underscores in its *name* rewritten to
    dashes. Option values and all other arguments are passed through
    untouched.
    """

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        def _normalize(token):
            if not token.startswith('--'):
                return token
            # Only the part before the first '=' is the option name.
            name, sep, value = token.partition('=')
            return '--' + name[len('--'):].replace('_', '-') + sep + value

        normalized = [_normalize(token) for token in args]
        return super().parse_args(normalized, namespace)
1137
1138
1139
1140
1141
1142
1143


async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)