utils.py 26.7 KB
Newer Older
1
import argparse
2
import asyncio
3
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
4
import enum
5
import gc
6
import os
7
import socket
8
import subprocess
9
import sys
10
11
import tempfile
import threading
Zhuohan Li's avatar
Zhuohan Li committed
12
import uuid
13
import warnings
14
from collections import defaultdict
15
from functools import lru_cache, partial, wraps
16
from platform import uname
17
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
18
                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
19
                    Union)
Zhuohan Li's avatar
Zhuohan Li committed
20

21
import numpy as np
22
import psutil
Zhuohan Li's avatar
Zhuohan Li committed
23
import torch
24
25
import torch.types
from typing_extensions import ParamSpec
26

27
import vllm.envs as envs
28
from vllm import _custom_ops as ops
29
from vllm.logger import enable_trace_function_call, init_logger
30
31
32
33
34
35
36

# Module-level logger for this file, created via vllm's init_logger.
logger = init_logger(__name__)

# Maps dtype name strings to torch dtypes. Every fp8 variant is stored as raw
# bytes, i.e. torch.uint8.
STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
    fp8=torch.uint8,
    fp8_e4m3=torch.uint8,
    fp8_e5m2=torch.uint8,
)
Zhuohan Li's avatar
Zhuohan Li committed
41

42
43
44
45
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")

Woosuk Kwon's avatar
Woosuk Kwon committed
46

47
48
49
50
51
52
53
class _Sentinel:
    ...


ALL_PINNED_SENTINEL = _Sentinel()


Woosuk Kwon's avatar
Woosuk Kwon committed
54
55
56
57
58
59
60
61
62
63
class Device(enum.Enum):
    """Device kinds: GPU or CPU."""
    # Explicit values equal to what enum.auto() assigns (starting at 1).
    GPU = 1
    CPU = 2


class Counter:
    """Integer counter advanced with ``next()``.

    ``next(counter)`` returns the current value and then increments it.
    ``reset()`` sets the value back to 0 (not to the original ``start``).
    """

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        current = self.counter
        self.counter = current + 1
        return current

    def reset(self) -> None:
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
71

72

73
class LRUCache(Generic[T]):
    """Least-recently-used cache whose entries can be pinned.

    ``self.cache`` is an OrderedDict ordered from least to most recently
    used. Whenever the number of entries exceeds ``capacity``, the least
    recently used *unpinned* entry is evicted. Subclasses can override
    ``_on_remove`` to react to entries leaving the cache.
    """

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> Optional[T]:
        # Unlike dict, a missing key yields None rather than KeyError.
        return self.get(key)

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        """Mark ``key`` as most recently used (KeyError if it is absent)."""
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Return the value for ``key`` and mark it most recently used.

        Returns ``default_value`` (without touching the order) on a miss.
        """
        if key not in self.cache:
            return default_value
        hit: Optional[T] = self.cache[key]
        self.cache.move_to_end(key)
        return hit

    def put(self, key: Hashable, value: T) -> None:
        """Insert or overwrite ``key`` as most recently used, then evict.

        If every other entry is pinned, the just-inserted unpinned key is the
        one evicted.
        """
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self.cache:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: Hashable) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: Hashable, value: Optional[T]):
        """Hook called after an existing entry is removed; no-op here."""

    def remove_oldest(self, remove_pinned=False):
        """Evict the least recently used entry.

        With ``remove_pinned=False`` pinned entries are skipped, and a
        RuntimeError is raised if every entry is pinned. Does nothing on an
        empty cache.
        """
        if not self.cache:
            return

        if remove_pinned:
            victim = next(iter(self.cache))
        else:
            # pop the oldest item in the cache that is not pinned
            for victim in self.cache:
                if victim not in self.pinned_items:
                    break
            else:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        self.pop(victim)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Remove ``key``, unpin it, and return its value (or the default).

        ``_on_remove`` runs only if the key was actually present.
        """
        was_present = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        # remove from pinned items
        if key in self.pinned_items:
            self._unpin(key)
        if was_present:
            self._on_remove(key, value)
        return value

    def clear(self):
        """Evict every entry, pinned ones included, firing ``_on_remove``."""
        while self.cache:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()


166
167
168
169
def is_hip() -> bool:
    """Return True when the installed PyTorch reports a HIP (ROCm) version."""
    hip_version = torch.version.hip
    return hip_version is not None


170
171
@lru_cache(maxsize=None)
def is_cpu() -> bool:
    """Return True if the installed vllm version string contains "cpu".

    Returns False when vllm is not installed as a distribution package.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "cpu" in vllm_version
177
178


179
180
181
182
183
184
185
186
187
@lru_cache(maxsize=None)
def is_openvino() -> bool:
    """Return True if the installed vllm version string contains "openvino".

    Returns False when vllm is not installed as a distribution package.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        vllm_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "openvino" in vllm_version


188
@lru_cache(maxsize=None)
def is_neuron() -> bool:
    """Return True if the ``transformers_neuronx`` package can be imported."""
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True


197
198
199
200
201
202
203
204
205
@lru_cache(maxsize=None)
def is_tpu() -> bool:
    """Return True if the ``libtpu`` package can be imported."""
    try:
        import libtpu  # noqa: F401
    except ImportError:
        return False
    return True


206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
@lru_cache(maxsize=None)
def is_xpu() -> bool:
    """Return True for an XPU build of vllm with IPEX and an available XPU.

    Checks, in order: the installed vllm version string contains "xpu",
    ``intel_extension_for_pytorch`` imports, and ``torch.xpu`` reports an
    available device. Returns False (instead of raising) when vllm is not
    installed as a distribution package, matching is_cpu()/is_openvino().
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        is_xpu_flag = "xpu" in version("vllm")
    except PackageNotFoundError:
        return False
    # vllm is not build with xpu
    if not is_xpu_flag:
        return False
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
        _import_ipex = True
    except ImportError as e:
        logger.warning("Import Error for IPEX: %s", e.msg)
        _import_ipex = False
    # ipex dependency is not ready
    if not _import_ipex:
        logger.warning("not found ipex lib")
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()


226
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    Args:
        gpu: device index forwarded to the custom op query.

    Raises:
        AssertionError: if the device reports a non-positive value.
    """
    limit = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert limit > 0, "max_shared_mem can not be zero"
    return int(limit)


237
def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    memory_info = psutil.virtual_memory()
    return memory_info.total
Zhuohan Li's avatar
Zhuohan Li committed
240
241
242
243


def random_uuid() -> str:
    """Return a random UUID4 as a 32-character hex string."""
    # UUID.hex is already a str, so no extra conversion is needed.
    return uuid.uuid4().hex
244

245

246
@lru_cache(maxsize=None)
def get_vllm_instance_id() -> str:
    """
    Return VLLM_INSTANCE_ID if set (non-empty), else a random
    "vllm-instance-<hex>" id. The result is cached, so every call in this
    process returns the same value. Instance id represents an instance of
    the VLLM. All processes in the same instance should have the same
    instance id.
    """
    configured = envs.VLLM_INSTANCE_ID
    if configured:
        return configured
    return f"vllm-instance-{random_uuid()}"
255
256


257
@lru_cache(maxsize=None)
def in_wsl() -> bool:
    """Return True if ``platform.uname()`` mentions "microsoft" (WSL)."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    platform_info = " ".join(uname()).lower()
    return "microsoft" in platform_info
261
262


263
def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking function so calls run in the default executor.

    Calling the returned wrapper submits ``func(*args, **kwargs)`` to the
    event loop's default executor and returns an awaitable future, so the
    blocking work does not stall the asyncio event loop. ``func`` runs on
    another thread and therefore must be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        event_loop = asyncio.get_event_loop()
        bound_call = partial(func, *args, **kwargs)
        return event_loop.run_in_executor(None, bound_call)

    return _async_wrapper


279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    Iterators may finish at different times. Each yielded value is a tuple
    ``(i, item)``, where ``i`` is the index of the source iterator. If an
    iterator raises an ``Exception``, it is re-raised to the consumer and
    all producer tasks are cancelled.

    Producer tasks are created immediately with ``asyncio.create_task``, so
    this must be called while an event loop is running.
    """
    queue: asyncio.Queue[Union[Tuple[int, T], Exception, object]] = \
        asyncio.Queue()
    # Each producer enqueues this marker exactly once, after its iterator is
    # exhausted or has failed. Counting markers avoids blocking forever in
    # queue.get(): polling "finished" flags could miss an iterator that ends
    # without enqueuing anything further.
    producer_done = object()

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        await queue.put(producer_done)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        remaining = len(_tasks)
        try:
            while remaining:
                item = await queue.get()
                if item is producer_done:
                    remaining -= 1
                    continue
                if isinstance(item, Exception):
                    raise item
                yield item
        except (Exception, asyncio.CancelledError) as e:
            for task in _tasks:
                if sys.version_info >= (3, 9):
                    # msg parameter only supported in Python 3.9+
                    task.cancel(e)
                else:
                    task.cancel()
            raise e
        await asyncio.gather(*_tasks)

    return consumer()


324
def get_ip() -> str:
    """Return an IP address of this host that peers can use to reach it.

    Resolution order:
      1. ``envs.VLLM_HOST_IP`` if it is non-empty.
      2. The local address of a UDP socket "connected" to 8.8.8.8 (IPv4).
      3. The local address of a UDP socket "connected" to Google's IPv6 DNS.
      4. "0.0.0.0", with a warning.
    Connecting a UDP socket only selects a route; the targets need not be
    reachable. Probe sockets are always closed.
    """
    host_ip = envs.VLLM_HOST_IP
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
            # Google's public DNS server, see
            # https://developers.google.com/speed/public-dns/docs/using#addresses
            # Doesn't need to be reachable
            s.connect(("2001:4860:4860::8888", 80))
            return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable"
        " VLLM_HOST_IP or HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"
355
356


357
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` URL, bracketing IPv6 hosts."""
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    host = f"[{ip}]" if ":" in ip else ip
    return f"tcp://{host}:{port}"
361
362


363
def get_open_port() -> int:
    """Return a TCP port number that could be bound at the time of the call.

    If ``envs.VLLM_PORT`` is set, ports are probed upward from that value
    until one binds (each collision is logged). Otherwise the OS picks an
    ephemeral port, over IPv4 first and IPv6 if the IPv4 attempt fails. The
    probe socket is closed before returning, so the port is not reserved.
    """
    candidate = envs.VLLM_PORT
    if candidate is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET,
                                   socket.SOCK_STREAM) as probe:
                    probe.bind(("", candidate))
                return candidate
            except OSError:
                # Already in use: log it and move on to the next port.
                logger.info("Port %d is already in use, trying port %d",
                            candidate, candidate + 1)
                candidate += 1

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]
    except OSError:
        # IPv4 bind failed: ask the OS for a port over IPv6 instead.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            return probe.getsockname()[1]
385
386


387
def update_environment_variables(envs: Dict[str, str]):
    """Copy every entry of ``envs`` into ``os.environ``.

    A warning is logged for each variable that already exists in the process
    environment with a different value. On ROCm (``is_hip()``), a
    CUDA_VISIBLE_DEVICES entry is also applied as HIP_VISIBLE_DEVICES.
    The caller's ``envs`` dict is not modified.
    """
    # Work on a copy so the HIP mirroring below does not mutate the
    # caller's dict.
    envs = dict(envs)
    if is_hip() and "CUDA_VISIBLE_DEVICES" in envs:
        # Propagate changes to CUDA_VISIBLE_DEVICES to
        # ROCm's HIP_VISIBLE_DEVICES as well
        envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"]
    for k, v in envs.items():
        if k in os.environ and os.environ[k] != v:
            logger.warning(
                "Overwriting environment variable %s "
                "from '%s' to '%s'", k, os.environ[k], v)
        os.environ[k] = v
398
399


400
def chunk_list(lst: List[T], chunk_size: int) -> List[List[T]]:
    """Split ``lst`` into consecutive chunks of ``chunk_size`` elements.

    Returns a list of lists (not a generator); the final chunk may be
    shorter than ``chunk_size``.
    """
    starts = range(0, len(lst), chunk_size)
    return [lst[begin:begin + chunk_size] for begin in starts]


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer >= a / b."""
    # ceil(a / b) == -floor(-a / b), computed exactly in integer arithmetic.
    return -(-a // b)


410
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill ``tensor`` with fp8 data derived from uniform(low, high) samples.

    Samples are drawn in float16 and then handed to ``ops.convert_fp8``
    together with ``tensor``. Presumably the kernel writes the fp8 encoding
    into ``tensor`` in place (argument order assumed dst, src; confirm).
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops
    staging = torch.empty_like(tensor, dtype=torch.float16).uniform_(low, high)
    ops.convert_fp8(tensor, staging)
    del staging


430
431
432
def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    """Resolve the torch dtype used to store the KV cache.

    Args:
        cache_dtype: "auto" to follow ``model_dtype``; any key of
            STR_DTYPE_TO_TORCH_DTYPE (e.g. "half", "bfloat16", "float", or an
            fp8 variant such as "fp8", "fp8_e4m3", "fp8_e5m2"); or a
            torch.dtype, which is returned as-is.
        model_dtype: only consulted for "auto"; a key of
            STR_DTYPE_TO_TORCH_DTYPE or a torch.dtype.

    Raises:
        ValueError: for an unknown cache dtype, or for "auto" with a
            ``model_dtype`` that is neither a str nor a torch.dtype.
        KeyError: for "auto" with an unknown ``model_dtype`` string.
    """
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            # Covers the float types and every fp8 variant listed in the
            # mapping (all fp8 variants map to raw uint8 storage).
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create seeded random per-layer key/value caches in a combined layout.

    For each layer, one (num_blocks, 2, block_size, num_heads, head_size)
    tensor is filled uniformly in [-head_size**-0.5, head_size**-0.5]. The
    key cache is its ``[:, 0]`` view and the value cache its ``[:, 1]`` view.
    A cache_dtype of "fp8" is rejected by an assertion.
    """
    assert cache_dtype != "fp8"
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    combined_shape = (num_blocks, 2, block_size, num_heads, head_size)
    bound = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        combined = torch.empty(size=combined_shape,
                               dtype=torch_dtype,
                               device=device).uniform_(-bound, bound)
        key_caches.append(combined[:, 0])
        value_caches.append(combined[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create seeded random per-layer key and value caches.

    Key caches have shape (num_blocks, num_heads, head_size // x, block_size,
    x), where x elements of the cache dtype span 16 bytes. Value caches have
    shape (num_blocks, num_heads, head_size, block_size). Values are uniform
    in [-head_size**-0.5, head_size**-0.5]; "fp8" caches are filled through
    _generate_random_fp8. All key caches are generated before any value
    cache.

    Raises:
        ValueError: if ``cache_dtype`` is not "auto", "half", "bfloat16",
            "float" or "fp8" (raised while building the first key cache).
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    # Number of elements of torch_dtype that fit in 16 bytes.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()

    def _random_cache(shape: Tuple[int, ...], kind: str) -> torch.Tensor:
        # Allocate one cache tensor and fill it according to cache_dtype.
        cache = torch.empty(size=shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support {kind} cache of type {cache_dtype}")
        return cache

    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = [
        _random_cache(key_cache_shape, "key") for _ in range(num_layers)
    ]

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = [
        _random_cache(value_cache_shape, "value") for _ in range(num_layers)
    ]
    return key_caches, value_caches
536
537


538
539
540
541
542
543
544
545
546
547
548
549
550
551
@lru_cache(maxsize=128)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` as a warning, skipping repeats of recently seen messages.

    Deduplication comes from lru_cache with its default size of 128 distinct
    messages; a message evicted from the cache can be logged again.
    """
    logger.warning(msg)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    """Return whether pinned host memory should be used on this platform.

    False under WSL, on XPU and on Neuron (each logs a one-time warning),
    and on CPU or OpenVINO builds. True otherwise.
    """
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    if is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
    if is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    return not (is_cpu() or is_openvino())


class CudaMemoryProfiler:
    """Context manager measuring device memory allocated within its block.

    Usage samples are taken on entry and on exit; afterwards
    ``consumed_memory`` holds their difference in bytes. CUDA and XPU are
    supported.
    """

    def __init__(self, device: Optional[torch.types.Device] = None):
        self.device = device

    def current_memory_usage(self) -> float:
        """Return the memory currently allocated on ``self.device`` in bytes.

        The peak statistics are reset first, so ``max_memory_allocated``
        reports the current allocation rather than a historical peak.

        Raises:
            RuntimeError: if neither CUDA nor XPU is available.
        """
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(self.device)
            mem = torch.cuda.max_memory_allocated(self.device)
        elif is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)
            mem = torch.xpu.max_memory_allocated(self.device)
        else:
            # Without this branch `mem` would be unbound and the caller would
            # see an UnboundLocalError.
            raise RuntimeError(
                "CudaMemoryProfiler requires a CUDA or XPU device, but "
                "neither is available.")
        return mem

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()
589
590


591
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Parse a comma-separated string such as "1,2,3" into (1, 2, 3).

    Whitespace around each number is accepted, since ``int()`` strips it.

    Raises:
        ValueError: if any comma-separated piece is not an integer.
    """
    pieces = s.split(",")
    try:
        return tuple(int(piece) for piece in pieces)
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


601
602
603
604
605
606
607
608
609
610
611
612
def make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
    """Make a padded tensor of a 2D inputs.

    Returns a tensor of shape (len(x), max_len). Row i holds x[i] followed
    by ``pad`` up to ``max_len``. The rows are assembled in an int32 NumPy
    buffer before conversion to ``dtype`` on ``device``. Asserts that no row
    is longer than ``max_len``.
    """
    padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
    for row, values in zip(padded_x, x):
        assert len(values) <= max_len
        # `row` is a view into padded_x, so this writes in place.
        row[:len(values)] = values
    return torch.tensor(padded_x, dtype=dtype, device=device)


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` and copy it to ``target_device``.

    The copy is issued with ``non_blocking=True``. It can only overlap with
    host work when the staging tensor is pinned (``pin_memory=True``).
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               pin_memory=pin_memory,
                               device="cpu")
    return host_tensor.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Reshape ``tensor`` to ``target_dims`` dimensions if it has fewer.

    The missing trailing dimensions get length ``size``; the leading
    dimension absorbs the rest (via ``view(-1, ...)``). Tensors that already
    have at least ``target_dims`` dimensions are returned unchanged.
    """
    missing = target_dims - tensor.ndim
    if missing <= 0:
        return tensor
    return tensor.view(-1, *([size] * missing))
638
639


640
641
642
643
644
def get_dtype_size(dtype: torch.dtype) -> int:
    """Return the size in bytes of one element of ``dtype``."""
    probe = torch.tensor([], dtype=dtype)
    return probe.element_size()


645
646
def merge_dicts(dict1: Dict[K, List[T]],
                dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
    """Merge two dicts whose values are lists.

    For a key present in both, the result holds dict1's items followed by
    dict2's (concatenation, nothing is dropped). Keys appear in dict1 order,
    then keys only in dict2. Inputs are not modified; every value in the
    result is a new list.
    """
    merged: Dict[K, List[T]] = defaultdict(list)
    for source in (dict1, dict2):
        for key, items in source.items():
            merged[key].extend(items)
    return dict(merged)
660
661


662
def init_cached_hf_modules() -> None:
    """Call transformers' ``init_hf_modules``.

    The import is done inside the function, so transformers is only loaded
    when this is actually called.
    """
    from transformers.dynamic_module_utils import (
        init_hf_modules as _init_hf_modules)
    _init_hf_modules()
668
669


670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.

    The ldconfig cache is searched first. Note that this is a substring
    match, so the first line containing `lib_name` wins. If nothing is
    found, each directory of LD_LIBRARY_PATH is checked for the exact file
    name. Raises ValueError if neither search finds it.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    ldconfig_output = subprocess.check_output(["/sbin/ldconfig",
                                               "-p"]).decode()
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    candidates = [
        entry.split()[-1] for entry in ldconfig_output.splitlines()
        if lib_name in entry
    ]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    search_path = envs.LD_LIBRARY_PATH
    if not candidates and search_path:
        candidates = [
            os.path.join(directory, lib_name)
            for directory in search_path.split(":")
            if os.path.exists(os.path.join(directory, lib_name))
        ]
    if not candidates:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return candidates[0]


698
def find_nccl_library() -> str:
    """
    Return the NCCL/RCCL shared library to load.

    VLLM_NCCL_SO_PATH is used when it is set. Otherwise the choice depends
    on the PyTorch build: libnccl.so.2 for CUDA and librccl.so.1 for ROCm.
    After importing `torch`, those names can be found by `ctypes`
    automatically.

    Raises:
        ValueError: if PyTorch is built for neither CUDA nor ROCm and no
            path was provided.
    """
    so_file = envs.VLLM_NCCL_SO_PATH

    # manually load the nccl library
    if so_file:
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
            so_file)
        return so_file

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info("Found nccl from library %s", so_file)
    return so_file
721
722
723
724
725
726
727


def enable_trace_function_call_for_thread() -> None:
    """Set up function tracing for the current thread,
    if enabled via the VLLM_TRACE_FUNCTION environment variable.

    The trace is written to
    <tempdir>/vllm/<instance id>/VLLM_TRACE_FUNCTION_for_process_<pid>
    _thread_<ident>_at_<timestamp>.log, with spaces replaced by underscores.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    timestamp = datetime.datetime.now()
    log_name = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                f"_thread_{threading.get_ident()}_"
                f"at_{timestamp}.log").replace(" ", "_")
    log_path = os.path.join(tempfile.gettempdir(), "vllm",
                            get_vllm_instance_id(), log_name)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777


def identity(value: T) -> T:
    """Return ``value`` unchanged (the very same object)."""
    result = value
    return result


F = TypeVar('F', bound=Callable[..., Any])


def deprecate_kwargs(
        *kws: str,
        is_deprecated: Union[bool, Callable[[], bool]] = True,
        additional_message: Optional[str] = None) -> Callable[[F], F]:
    """Return a decorator that warns when deprecated keyword args are passed.

    Args:
        *kws: names of the deprecated keyword arguments.
        is_deprecated: a bool, or a zero-argument callable evaluated on every
            call of the decorated function, deciding whether to warn.
        additional_message: optional text appended to the warning.

    The decorated function emits a DeprecationWarning attributed to the code
    that called it, then runs unchanged.
    """
    deprecated_kws = set(kws)

    def _check_deprecated() -> bool:
        # Accept either a fixed flag or a callable evaluated lazily.
        if callable(is_deprecated):
            return is_deprecated()
        return is_deprecated

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if _check_deprecated():
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # stacklevel=1 is this `inner` frame, so 2 points at
                        # the code that called the decorated function.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    """Count GPUs without initializing the CUDA runtime when possible.

    Returns 0 when PyTorch is not compiled with CUDA. Otherwise it queries
    amdsmi (ROCm, if this torch exposes it) or NVML. It falls back to
    ``torch._C._cuda_getDeviceCount()`` when that query yields a negative
    count.
    """
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    if not torch.cuda._is_compiled():
        return 0

    if is_hip():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        if hasattr(torch.cuda, "_device_count_amdsmi"):
            raw_count = torch.cuda._device_count_amdsmi()
        else:
            raw_count = -1
    else:
        raw_count = torch.cuda._device_count_nvml()

    if raw_count < 0:
        # A negative count means the stateless query was unavailable.
        return torch._C._cuda_getDeviceCount()
    return raw_count


def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = envs.CUDA_VISIBLE_DEVICES
    # Passing the current value makes it part of the lru_cache key.
    return _cuda_device_count_stateless(visible_devices)
817
818
819
820
821
822
823
824
825
826
827
828


# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):
    """Decorator that lets ``f`` run only on its first call.

    The first call returns ``f``'s result. Every later call returns None
    without invoking ``f``. The wrapper exposes a boolean ``has_run``
    attribute, and ``functools.wraps`` keeps ``f``'s name, docstring and
    ``__wrapped__`` for introspection.
    """

    @wraps(f)
    def wrapper(*args, **kwargs) -> Any:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)
        return None

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
829
830
831
832
833
834
835
836
837
838
839
840
841


class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names.

    For every token starting with ``--``, underscores in the option name
    (the part before the first ``=``) become dashes, so ``--max_model_len``
    and ``--max-model-len`` are equivalent. Values are left untouched.
    """

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        def _normalize(token):
            # Only long options are rewritten; values and positionals pass
            # through as-is.
            if not token.startswith('--'):
                return token
            name, sep, value = token.partition('=')
            return '--' + name[len('--'):].replace('_', '-') + sep + value

        normalized = [_normalize(token) for token in args]
        return super().parse_args(normalized, namespace)