__init__.py 54.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import enum
7
import getpass
8
import importlib
9
import inspect
10
import ipaddress
11
import json
12
import multiprocessing
13
import os
14
import signal
15
import socket
16
import subprocess
17
import sys
18
import tempfile
19
import textwrap
20
import threading
21
import traceback
Zhuohan Li's avatar
Zhuohan Li committed
22
import uuid
23
import warnings
24
import weakref
25
26
27
28
29
30
31
32
from argparse import (
    Action,
    ArgumentDefaultsHelpFormatter,
    ArgumentParser,
    ArgumentTypeError,
    RawDescriptionHelpFormatter,
    _ArgumentGroup,
)
33
from collections import defaultdict
34
from collections.abc import (
35
    Callable,
36
37
38
    Iterator,
    Sequence,
)
39
from concurrent.futures.process import ProcessPoolExecutor
40
from functools import cache, partial, wraps
41
from pathlib import Path
42
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
43
from urllib.parse import urlparse
44
from uuid import uuid4
Zhuohan Li's avatar
Zhuohan Li committed
45

46
import cloudpickle
47
import psutil
48
import regex as re
49
import setproctitle
Zhuohan Li's avatar
Zhuohan Li committed
50
import torch
51
import yaml
52
53
import zmq
import zmq.asyncio
54

55
import vllm.envs as envs
56
from vllm.logger import enable_trace_function_call, init_logger
57
from vllm.ray.lazy_utils import is_in_ray_actor
58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Profiling helpers that moved to vllm.utils.profiling. They remain reachable
# from this module (with a DeprecationWarning) through __getattr__ below.
_DEPRECATED_PROFILING = {"cprofile", "cprofile_context"}


def __getattr__(name: str) -> Any:  # noqa: D401 - short deprecation docstring
    """Module-level getattr to handle deprecated profiling utilities.

    Python only calls this for names missing from the module namespace
    (PEP 562). Deprecated profiling names are forwarded to
    ``vllm.utils.profiling`` after a ``DeprecationWarning``; any other name
    raises ``AttributeError`` like a normal attribute miss.
    """
    if name not in _DEPRECATED_PROFILING:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(
        f"vllm.utils.{name} is deprecated and will be removed in a future version. "
        f"Use vllm.utils.profiling.{name} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Deferred import: the profiling module is only loaded when one of the
    # deprecated names is actually requested.
    import vllm.utils.profiling as _prof

    return getattr(_prof, name)


def __dir__() -> list[str]:
    """Return the sorted module attribute names plus the deprecated aliases.

    Keeps the deprecated names visible to ``dir()`` and tab completion.
    """
    names = [*globals(), *_DEPRECATED_PROFILING]
    names.sort()
    return names


82
if TYPE_CHECKING:
83
84
    from argparse import Namespace

85
    from vllm.config import ModelConfig, VllmConfig
86
87
88
89
90
else:
    Namespace = object

    ModelConfig = object
    VllmConfig = object
91

92
93
# Logger named after this module (``__name__``), created via vLLM's init_logger.
logger = init_logger(__name__)

94
95
96
97
98
99
# Default cap on the number of batched tokens. It is chosen to balance
# inter-token latency (ITL) against time to first token (TTFT); it is not
# tuned for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
# Counterparts of the default above for pooling and multimodal models.
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

# Forcing the attention backend:
# name of the environment variable that can be set to force the Attention
# wrapper's automatic backend selection to a specific backend.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Values STR_BACKEND_ENV_VAR may hold, one per backend (plus a sentinel).
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

# ANSI escape sequences: bold cyan text, and reset to default attributes.
CYAN = "\033[1;36m"
RESET = "\033[0;0m"

# Generic type variables.
T = TypeVar("T")
U = TypeVar("U")
123

124

Woosuk Kwon's avatar
Woosuk Kwon committed
125
126
127
128
129
class Device(enum.Enum):
    """Coarse device category: GPU or CPU (values assigned by ``enum.auto``)."""

    GPU = enum.auto()
    CPU = enum.auto()


130
131
132
133
134
class LayerBlockType(enum.Enum):
    """Kind of layer block; every member's value is the same as its name."""

    attention = "attention"
    mamba = "mamba"


Woosuk Kwon's avatar
Woosuk Kwon committed
135
136
137
138
class Counter:
    """Simple integer counter.

    ``next(counter)`` returns the current value and then advances it by one.
    The class implements the complete iterator protocol (``__iter__`` returns
    the counter itself), so it also works in ``for`` loops, ``zip`` and
    ``itertools`` helpers, not just with ``next()``.
    """

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __iter__(self) -> "Counter":
        return self

    def __next__(self) -> int:
        current = self.counter
        self.counter += 1
        return current

    def reset(self) -> None:
        """Set the counter back to 0 (regardless of the ``start`` value)."""
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
146

147

Zhuohan Li's avatar
Zhuohan Li committed
148
149
def random_uuid() -> str:
    """Return a random UUID4 as a 32-character lowercase hex string (no dashes)."""
    # ``UUID.hex`` is already a ``str``; no extra conversion needed.
    return uuid.uuid4().hex
150

151

152
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
    """Close each ZMQ socket in ``sockets`` without lingering.

    ``None`` entries are tolerated and skipped. Every socket is closed with
    ``linger=0`` so the close does not wait on unsent messages.
    """
    for zmq_socket in sockets:
        if zmq_socket is None:
            continue
        zmq_socket.close(linger=0)


158
def get_ip() -> str:
    """Return the IP address this vLLM process should use to reach its peers.

    Resolution order:
      1. ``VLLM_HOST_IP`` (read through ``envs``) when set and non-empty.
      2. The local address the kernel would use to route to a public IPv4 host.
      3. The same for a public IPv6 host.
      4. ``"0.0.0.0"``, with a warning.

    Steps 2 and 3 use UDP sockets, for which ``connect`` only selects a route
    and local address -- which is why the target does not need to be
    reachable. The deprecated ``HOST_IP`` variable is never used (a warning is
    logged if it is set without ``VLLM_HOST_IP``).
    """
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
        logger.warning(
            "The environment variable HOST_IP is deprecated and ignored, as"
            " it is often used by Docker and other software to"
            " interact with the container's network stack. Please "
            "use VLLM_HOST_IP instead to set the IP address for vLLM processes"
            " to communicate with each other."
        )
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface: IPv4 first,
    # then IPv6. The targets are Google's public DNS servers, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses
    probes = (
        (socket.AF_INET, "8.8.8.8"),
        (socket.AF_INET6, "2001:4860:4860::8888"),
    )
    for family, target in probes:
        try:
            # ``with`` closes the probe socket on every path; socket creation
            # is inside the ``try`` so an unsupported family falls through too.
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect((target, 80))  # Doesn't need to be reachable
                return s.getsockname()[0]
        except Exception:
            # Best effort: move on to the next family / the default below.
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable VLLM_HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"
198
199


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def test_loopback_bind(address, family):
    """Return True if a UDP socket of ``family`` can bind to ``address``.

    Used to probe whether a loopback address such as ``127.0.0.1`` or ``::1``
    is usable on this host. Port 0 lets the OS pick any free port. Any
    ``OSError`` (unsupported family, address not available, ...) yields
    False. The probe socket is closed on every path, including when
    ``bind`` fails (it previously leaked in that case).
    """
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as s:
            s.bind((address, 0))  # Port 0 = auto assign
    except OSError:
        return False
    return True


# The ``test_`` prefix makes pytest collect this helper as a test whenever it
# is imported into a test module (where it then errors on the missing
# ``address``/``family`` fixtures); opt it out of collection explicitly.
test_loopback_bind.__test__ = False  # type: ignore[attr-defined]


def get_loopback_ip() -> str:
    """Return the loopback address vLLM should use for local communication.

    ``VLLM_LOOPBACK_IP`` wins when set. Otherwise ``127.0.0.1`` is preferred,
    then ``::1``, whichever can actually be bound.

    Raises:
        RuntimeError: if neither loopback address can be bound.
    """
    configured = envs.VLLM_LOOPBACK_IP
    if configured:
        return configured

    # VLLM_LOOPBACK_IP is not set: probe the interfaces, IPv4 first.
    candidates = (("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6))
    for candidate, family in candidates:
        if test_loopback_bind(candidate, family):
            return candidate

    raise RuntimeError(
        "Neither 127.0.0.1 nor ::1 are bound to a local interface. "
        "Set the VLLM_LOOPBACK_IP environment variable explicitly."
    )
226
227


228
229
230
231
232
233
234
235
def is_valid_ipv6_address(address: str) -> bool:
    """Return True if ``address`` parses as a bare IPv6 address (no brackets/port)."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    else:
        return True


236
def split_host_port(host_port: str) -> tuple[str, int]:
    """Split ``"host:port"`` or ``"[ipv6]:port"`` into ``(host, port)``.

    IPv6 hosts must be bracketed; the brackets are stripped from the result.
    Malformed input raises ``ValueError`` (or ``IndexError`` when a bracketed
    host is not followed by ``:port``).
    """
    if not host_port.startswith("["):
        # Exactly one ':' is expected; more (e.g. a bare IPv6) raises ValueError.
        hostname, port_text = host_port.split(":")
        return hostname, int(port_text)

    # "[::1]:8000" -> ("[::1", ":8000")
    bracketed_host, remainder = host_port.rsplit("]", 1)
    return bracketed_host[1:], int(remainder.split(":")[1])


def join_host_port(host: str, port: int) -> str:
    """Join ``host`` and ``port`` as ``host:port``, bracketing IPv6 hosts.

    This is the format ``split_host_port`` accepts.
    """
    host_part = f"[{host}]" if is_valid_ipv6_address(host) else host
    return f"{host_part}:{port}"


255
def get_distributed_init_method(ip: str, port: int) -> str:
    """Return the ``tcp://`` init-method URI for ``ip``/``port``.

    Thin alias of ``get_tcp_uri``.
    """
    init_method = get_tcp_uri(ip, port)
    return init_method


def get_tcp_uri(ip: str, port: int) -> str:
    """Return ``tcp://<ip>:<port>``, wrapping IPv6 addresses in brackets."""
    # join_host_port applies the same IPv6 bracketing rule.
    return "tcp://" + join_host_port(ip, port)
264
265


266
267
268
269
270
def get_open_zmq_ipc_path() -> str:
    """Return a fresh ``ipc://`` endpoint located under ``VLLM_RPC_BASE_PATH``.

    The endpoint name is a random UUID4, so separate calls yield distinct paths.
    """
    return "ipc://{}/{}".format(envs.VLLM_RPC_BASE_PATH, uuid4())


271
272
273
274
def get_open_zmq_inproc_path() -> str:
    """Return a fresh ``inproc://`` endpoint named by a random UUID4."""
    endpoint_name = uuid4()
    return f"inproc://{endpoint_name}"


275
def get_open_port() -> int:
    """
    Get an open port for the vLLM process to listen on.

    When running data parallel (``VLLM_DP_MASTER_PORT`` present in the
    environment), the 10 ports starting at ``VLLM_DP_MASTER_PORT`` are
    reserved for the data parallel master process (it currently uses 2).
    Candidates inside that range are discarded and a new one is requested.
    """
    if "VLLM_DP_MASTER_PORT" not in os.environ:
        return _get_open_port()

    first_reserved = envs.VLLM_DP_MASTER_PORT
    reserved = range(first_reserved, first_reserved + 10)
    # NOTE(review): if VLLM_PORT is set to a free port inside ``reserved``,
    # _get_open_port keeps returning that same port and this loop never ends.
    candidate = _get_open_port()
    while candidate in reserved:
        candidate = _get_open_port()
    return candidate

youkaichao's avatar
youkaichao committed
293

294
295
def get_open_ports_list(count: int = 5) -> list[int]:
    """Get a list of open ports."""
296
    ports = set[int]()
297
298
299
300
301
    while len(ports) < count:
        ports.add(get_open_port())
    return list(ports)


302
def _get_open_port() -> int:
    """Find a free TCP port (helper behind ``get_open_port``).

    With ``VLLM_PORT`` set, probe upward from it and return the first port
    that binds on IPv4. Otherwise let the OS pick an ephemeral port on IPv4,
    falling back to IPv6 if that bind fails. The probe socket is released
    before returning, so the port is free at that moment but not reserved.
    """
    port = envs.VLLM_PORT
    if port is None:
        # Port 0 asks the OS for any free port.
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(("", 0))
                return probe.getsockname()[1]
        except OSError:
            with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as probe:
                probe.bind(("", 0))
                return probe.getsockname()[1]

    while True:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(("", port))
                return port
        except OSError:
            port += 1  # Increment port number if already in use
            logger.info("Port %d is already in use, trying port %d", port - 1, port)
323
324


325
def find_process_using_port(port: int) -> psutil.Process | None:
    """Return another process that has a connection on local ``port``, if any.

    The current process is ignored, as are connections without a pid.
    Returns ``None`` on macOS, when no match exists, or when the matching
    process vanishes before it can be wrapped in ``psutil.Process``.
    """
    # TODO: We can not check for running processes with network
    # port on macOS. Therefore, we can not have a full graceful shutdown
    # of vLLM. For now, let's not look for processes in this case.
    # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
    if sys.platform.startswith("darwin"):
        return None

    own_pid = os.getpid()
    for connection in psutil.net_connections():
        if connection.laddr.port != port:
            continue
        if connection.pid is None or connection.pid == own_pid:
            continue
        try:
            return psutil.Process(connection.pid)
        except psutil.NoSuchProcess:
            return None
    return None


343
def update_environment_variables(envs: dict[str, str]):
    """Copy every ``name -> value`` pair from ``envs`` into ``os.environ``.

    A warning is logged for each variable that already exists with a
    different value before it is overwritten.
    """
    for name, new_value in envs.items():
        old_value = os.environ.get(name)
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                name,
                old_value,
                new_value,
            )
        os.environ[name] = new_value
353
354


355
356
357
358
359
def cdiv(a: int, b: int) -> int:
    """Ceiling division, i.e. ``ceil(a / b)``.

    Negating floor division by ``-b`` keeps the computation in exact integer
    arithmetic, with no float rounding even for very large values.
    """
    negated_floor = a // -b
    return -negated_floor


360
361
362
363
364
365
366
def next_power_of_2(n) -> int:
    """Smallest power of 2 that is >= ``n`` (powers of 2 map to themselves).

    Any ``n`` below 1 yields 1.
    """
    return 1 if n < 1 else 1 << (n - 1).bit_length()


367
368
369
370
371
372
373
def prev_power_of_2(n: int) -> int:
    """Largest power of 2 that is <= ``n``; non-positive ``n`` yields 0."""
    return 0 if n <= 0 else 1 << (n.bit_length() - 1)


374
375
376
377
def round_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y`` (intended for ``y > 0``)."""
    num_multiples = (x + y - 1) // y
    return num_multiples * y


378
379
380
381
def round_down(x: int, y: int) -> int:
    """Round ``x`` down to the nearest multiple of ``y`` (floor semantics)."""
    whole_multiples = x // y
    return whole_multiples * y


382
@cache
def is_pin_memory_available() -> bool:
    """Return whether the current platform supports pinned host memory.

    Delegates to ``current_platform``. The result is cached, so the platform
    is only queried once per process.
    """
    from vllm.platforms import current_platform as platform

    return platform.is_pin_memory_available()
387
388


389
390
391
392
393
394
395
396
@cache
def is_uva_available() -> bool:
    """Check if Unified Virtual Addressing (UVA) is available.

    UVA needs pinned host memory, which is currently the only requirement
    checked. TODO: add more requirements for UVA if needed.
    """
    pinned_memory_ok = is_pin_memory_available()
    return pinned_memory_ok


397
398
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
    """Lazily initialize the Hugging Face dynamic-module cache.

    Thin wrapper around ``transformers.dynamic_module_utils.init_hf_modules``.
    ``transformers`` is imported only when this function is called.
    """
    from transformers.dynamic_module_utils import init_hf_modules as _init

    _init()
406
407


408
@cache
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.

    The ldconfig cache is consulted first, matching lines by substring. If it
    has no match -- or if `/sbin/ldconfig` is missing or fails (e.g. in
    minimal containers) -- the directories in `LD_LIBRARY_PATH` are searched.

    Raises:
        ValueError: if the library cannot be located by either method.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    try:
        # errors="replace": one non-UTF-8 path in the cache must not make the
        # whole lookup fail.
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode(
            errors="replace"
        )
    except (OSError, subprocess.CalledProcessError):
        # Treat a missing/failing ldconfig as "no matches" so that the
        # LD_LIBRARY_PATH fallback below still runs.
        libs = ""
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = envs.LD_LIBRARY_PATH
    if not locs and env_ld_library_path:
        candidates = (
            os.path.join(directory, lib_name)
            for directory in env_ld_library_path.split(":")
        )
        locs = [path for path in candidates if os.path.exists(path)]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]


436
def find_nccl_library() -> str:
    """Return the NCCL/RCCL shared-library name or path to load.

    The library file named by `VLLM_NCCL_SO_PATH` takes precedence. Otherwise
    the library brought by PyTorch is used: `libnccl.so.2` for CUDA builds or
    `librccl.so.1` for ROCm builds. After importing `torch`, these names can
    be found by `ctypes` automatically.

    Raises:
        ValueError: if torch is built for neither CUDA nor ROCm.
    """
    so_file = envs.VLLM_NCCL_SO_PATH
    if so_file:
        # manually load the nccl library
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
        )
        return so_file

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.debug_once("Found nccl from library %s", so_file)
    return so_file
459
460


461
def find_nccl_include_paths() -> list[str] | None:
    """
    Collect include directories that provide the NCCL header `nccl.h`.

    Candidates, in order:
      1. `VLLM_NCCL_INCLUDE_PATH`, if it names an existing directory.
      2. `<location>/include` for each search location of the `nvidia.nccl`
         package (installed by nvidia-nccl-cuXX) that contains `nccl.h`.

    Empty entries and duplicates are dropped, keeping first-seen order.
    Returns None when nothing is found -- presumably so the caller can fall
    back to `load_inline`'s default `torch.utils.cpp_extension.include_paths`.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded; without this import the lookup below could raise
    # AttributeError, which the best-effort handler would silently swallow.
    import importlib.util

    paths: list[str] = []
    inc = envs.VLLM_NCCL_INCLUDE_PATH
    if inc and os.path.isdir(inc):
        paths.append(inc)

    try:
        spec = importlib.util.find_spec("nvidia.nccl")
    except Exception:
        # Best effort: e.g. a missing `nvidia` parent package raises
        # ModuleNotFoundError here; any failure just means "not found".
        spec = None
    for loc in getattr(spec, "submodule_search_locations", None) or ():
        inc_dir = os.path.join(loc, "include")
        if os.path.exists(os.path.join(inc_dir, "nccl.h")):
            paths.append(inc_dir)

    # dict preserves insertion order, so this dedups while keeping order.
    unique_paths = list(dict.fromkeys(p for p in paths if p))
    return unique_paths or None


492
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
    """Turn on function-call tracing for the calling thread.

    Does nothing unless the VLLM_TRACE_FUNCTION environment variable is
    enabled. The trace is written to
    `<tmp>/<user>/vllm/vllm-instance-<instance_id>/` in a file named after
    the process id, thread id and current timestamp, with spaces in the file
    name replaced by underscores.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    # A per-user subdirectory avoids permission issues in a shared temp dir.
    user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
    log_name = (
        f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
        f"_thread_{threading.get_ident()}_"
        f"at_{datetime.datetime.now()}.log"
    ).replace(" ", "_")
    log_dir = os.path.join(
        user_tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}"
    )
    os.makedirs(log_dir, exist_ok=True)
    enable_trace_function_call(os.path.join(log_dir, log_name))
511
512


513
514
515
516
517
518
519
def cuda_is_initialized() -> bool:
    """Return True if CUDA has already been initialized in this process.

    Builds of torch compiled without CUDA always report False.
    """
    return torch.cuda._is_compiled() and torch.cuda.is_initialized()


520
521
522
523
524
525
526
def xpu_is_initialized() -> bool:
    """Check if XPU is initialized."""
    if not torch.xpu._is_compiled():
        return False
    return torch.xpu.is_initialized()


527
528
529
def cuda_get_device_properties(
    device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
    """Get specified CUDA device property values without initializing CUDA in
    the current process.

    If CUDA is already initialized here, or ``init_cuda`` is True, the
    properties are read directly. Otherwise the function re-invokes itself
    (with ``init_cuda=True``) in a single-worker process pool that uses the
    ``fork`` start method, so any CUDA initialization happens in the child.
    """
    if not (init_cuda or cuda_is_initialized()):
        fork_ctx = multiprocessing.get_context("fork")
        with ProcessPoolExecutor(max_workers=1, mp_context=fork_ctx) as pool:
            future = pool.submit(cuda_get_device_properties, device, names, True)
            return future.result()

    props = torch.cuda.get_device_properties(device)
    return tuple(getattr(props, attr) for attr in names)
540
541


542
543
544
def weak_bind(
    bound_method: Callable[..., Any],
) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected.

    The returned callable forwards its arguments to the underlying function
    with the still-alive instance as ``self`` and discards the return value.
    It does not keep the instance alive.
    """
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        # Compare with None instead of testing truthiness: a live instance
        # whose class defines __bool__/__len__ can be falsy (e.g. an empty
        # container) and must still receive the call.
        if (inst := ref()) is not None:
            unbound(inst, *args, **kwargs)

    return weak_bound


558
class StoreBoolean(Action):
    """argparse action storing ``True``/``False`` for "true"/"false" (any case)."""

    def __call__(self, parser, namespace, values, option_string=None):
        normalized = values.lower()
        if normalized not in ("true", "false"):
            raise ValueError(
                f"Invalid boolean value: {values}. Expected 'true' or 'false'."
            )
        setattr(namespace, self.dest, normalized == "true")
568
569


570
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    """SortedHelpFormatter that sorts arguments by their option strings."""

    # Compiled once per class instead of on every help string. Both patterns
    # also consume the whitespace that follows the newline(s).
    _SINGLE_NEWLINE = re.compile(r"(?<!\n)\n(?!\n)\s*")
    _MULTIPLE_NEWLINES = re.compile(r"\n{2,}\s*")

    def _split_lines(self, text, width):
        """
        1. Sentences split across lines have their single newlines removed.
        2. Paragraphs and explicit newlines are split into separate lines.
        3. Each line is wrapped to the specified width (width of terminal).
        """
        text = self._SINGLE_NEWLINE.sub(" ", text)
        lines = self._MULTIPLE_NEWLINES.split(text)
        # Flatten with a comprehension; sum(list_of_lists, []) is quadratic.
        return [wrapped for line in lines for wrapped in textwrap.wrap(line, width)]

    def add_arguments(self, actions):
        """Add ``actions`` to the help output, sorted by their option strings."""
        actions = sorted(actions, key=lambda x: x.option_strings)
        super().add_arguments(actions)
589
590


591
class FlexibleArgumentParser(ArgumentParser):
592
593
    """ArgumentParser that allows both underscore and dash in names."""

    # Actions registered with ``deprecated=True`` through the emulated
    # add_argument below; a class attribute, so shared by all instances.
    _deprecated: set[Action] = set()
    # Explanation of the JSON / dotted-key CLI syntax, appended to help output.
    _json_tip: str = (
        "When passing JSON CLI arguments, the following sets of arguments "
        "are equivalent:\n"
        '   --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
        "   --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
        "Additionally, list elements can be passed individually using +:\n"
        '   --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
        "   --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
    )
    # Keyword from ``--help=<keyword>``; set in parse_args, read by format_help.
    _search_keyword: str | None = None
605

606
    def __init__(self, *args, **kwargs):
        """Create the parser.

        ``formatter_class`` defaults to ``SortedHelpFormatter``. The extra
        keyword ``add_json_tip`` (default True) is consumed here and stored
        on the instance; everything else is passed to ``ArgumentParser``.
        """
        kwargs.setdefault("formatter_class", SortedHelpFormatter)
        # Not an ArgumentParser option, so strip it before calling super().
        self.add_json_tip = kwargs.pop("add_json_tip", True)
        super().__init__(*args, **kwargs)

614
    if sys.version_info < (3, 13):
        # Enable the deprecated kwarg for Python 3.12 and below: actions
        # added with ``deprecated=True`` are recorded in the class-wide
        # ``_deprecated`` set, and parse_known_args warns when such an
        # argument ends up with a non-default value.

        def parse_known_args(self, args=None, namespace=None):
            if args is not None and "--disable-log-requests" in args:
                # Special case warning because the generic warning below won't
                # trigger for --disable-log-requests, since its value is the
                # default.
                logger.warning_once(
                    "argument '--disable-log-requests' is deprecated and "
                    "replaced with '--enable-log-requests'. This will be "
                    "removed in v0.12.0."
                )
            namespace, args = super().parse_known_args(args, namespace)
            for action in FlexibleArgumentParser._deprecated:
                dest = action.dest
                if not hasattr(namespace, dest):
                    continue
                if getattr(namespace, dest) != action.default:
                    logger.warning_once("argument '%s' is deprecated", dest)
            return namespace, args

        def add_argument(self, *args, **kwargs):
            # Strip the kwarg argparse does not know about yet, then record
            # the resulting action if it was flagged as deprecated.
            is_deprecated = kwargs.pop("deprecated", False)
            new_action = super().add_argument(*args, **kwargs)
            if is_deprecated:
                FlexibleArgumentParser._deprecated.add(new_action)
            return new_action

        class _FlexibleArgumentGroup(_ArgumentGroup):
            # Same ``deprecated=`` emulation for arguments added to a group.
            def add_argument(self, *args, **kwargs):
                is_deprecated = kwargs.pop("deprecated", False)
                new_action = super().add_argument(*args, **kwargs)
                if is_deprecated:
                    FlexibleArgumentParser._deprecated.add(new_action)
                return new_action

        def add_argument_group(self, *args, **kwargs):
            # Mirrors ArgumentParser.add_argument_group, but builds a group
            # that understands ``deprecated=``.
            new_group = self._FlexibleArgumentGroup(self, *args, **kwargs)
            self._action_groups.append(new_group)
            return new_group
654

655
656
657
658
659
660
661
662
663
664
665
666
    def format_help(self):
        """Render help text, honouring a ``--help=<keyword>`` search.

        Parsers that own subparsers use the stock argparse help. Otherwise,
        depending on the search keyword (lower-cased, ``_`` turned into ``-``):

        * no keyword: usage, description, a one-line-per-group
          "Config Groups" summary, then the epilog;
        * ``all``: the stock full help with the JSON tip as epilog;
        * a group title (case-insensitive): that group's arguments;
        * anything else: every argument whose option strings contain the
          keyword, or a "no match" hint when there are none.
        """
        # Only use custom help formatting for bottom level parsers
        if self._subparsers is not None:
            return super().format_help()

        formatter = self._get_formatter()
        keyword = self._search_keyword

        if keyword is None:
            formatter.add_usage(
                self.usage, self._actions, self._mutually_exclusive_groups
            )
            formatter.add_text(self.description)
            # One summary line per non-empty group: padded title + description.
            formatter.start_section("Config Groups")
            summary_lines = [
                f"{group.title: <24}{group.description or ''}\n"
                for group in self._action_groups
                if group._group_actions
            ]
            formatter.add_text("".join(summary_lines))
            formatter.end_section()
            formatter.add_text(self.epilog)
            return formatter.format_help()

        keyword = keyword.lower().replace("_", "-")
        if keyword == "all":
            self.epilog = self._json_tip
            return super().format_help()

        # Exact (case-insensitive) group title match.
        for group in self._action_groups:
            if group.title and group.title.lower() == keyword:
                formatter.start_section(group.title)
                formatter.add_text(group.description)
                formatter.add_arguments(group._group_actions)
                formatter.end_section()
                formatter.add_text(self._json_tip)
                return formatter.format_help()

        # Substring match against option names.
        matches = [
            action
            for group in self._action_groups
            for action in group._group_actions
            if any(keyword in opt.lower() for opt in action.option_strings)
        ]
        if matches:
            formatter.start_section(f"Arguments matching '{keyword}'")
            formatter.add_arguments(matches)
            formatter.end_section()
            formatter.add_text(self._json_tip)
            return formatter.format_help()

        formatter.add_text(
            f"No group or arguments matching '{keyword}'.\n"
            "Use '--help' to see available groups or "
            "'--help=all' to see all available parameters."
        )
        return formatter.format_help()
728

729
730
731
732
733
    def parse_args(  # type: ignore[override]
        self,
        args: list[str] | None = None,
        namespace: Namespace | None = None,
    ):
734
735
736
        if args is None:
            args = sys.argv[1:]

737
738
        # Check for --model in command line arguments first
        if args and args[0] == "serve":
739
740
            try:
                model_idx = next(
741
742
743
744
                    i
                    for i, arg in enumerate(args)
                    if arg == "--model" or arg.startswith("--model=")
                )
745
                logger.warning(
746
747
                    "With `vllm serve`, you should provide the model as a "
                    "positional argument or in a config file instead of via "
748
                    "the `--model` option. "
749
750
                    "The `--model` option will be removed in v0.13."
                )
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772

                if args[model_idx] == "--model":
                    model_tag = args[model_idx + 1]
                    rest_start_idx = model_idx + 2
                else:
                    model_tag = args[model_idx].removeprefix("--model=")
                    rest_start_idx = model_idx + 1

                # Move <model> to the front, e,g:
                # [Before]
                # vllm serve -tp 2 --model <model> --enforce-eager --port 8001
                # [After]
                # vllm serve <model> -tp 2 --enforce-eager --port 8001
                args = [
                    "serve",
                    model_tag,
                    *args[1:model_idx],
                    *args[rest_start_idx:],
                ]
                print("args", args)
            except StopIteration:
                pass
773

774
        if "--config" in args:
775
            args = self._pull_args_from_config(args)
776

777
778
779
780
781
782
783
        def repl(match: re.Match) -> str:
            """Replaces underscores with dashes in the matched string."""
            return match.group(0).replace("_", "-")

        # Everything between the first -- and the first .
        pattern = re.compile(r"(?<=--)[^\.]*")

784
        # Convert underscores to dashes and vice versa in argument names
785
        processed_args = list[str]()
786
        for i, arg in enumerate(args):
787
            if arg.startswith("--help="):
788
                FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
789
                processed_args.append("--help")
790
791
792
            elif arg.startswith("--"):
                if "=" in arg:
                    key, value = arg.split("=", 1)
793
                    key = pattern.sub(repl, key, count=1)
794
                    processed_args.append(f"{key}={value}")
795
                else:
796
797
                    key = pattern.sub(repl, arg, count=1)
                    processed_args.append(key)
798
            elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
799
800
                # allow -O flag to be used without space, e.g. -O3 or -Odecode
                # -O.<...> handled later
801
802
803
                # also handle -O=<mode> here
                mode = arg[3:] if arg[2] == "=" else arg[2:]
                processed_args.append(f"-O.mode={mode}")
804
805
806
807
808
            elif (
                arg == "-O"
                and i + 1 < len(args)
                and args[i + 1] in {"0", "1", "2", "3"}
            ):
809
810
                # Convert -O <n> to -O.mode <n>
                processed_args.append("-O.mode")
811
812
813
            else:
                processed_args.append(arg)

814
        def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
815
816
817
818
819
820
821
822
823
824
            """Creates a nested dictionary from a list of keys and a value.

            For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
            `{"a": {"b": {"c": 1}}}`
            """
            nested_dict: Any = value
            for key in reversed(keys):
                nested_dict = {key: nested_dict}
            return nested_dict

825
826
827
        def recursive_dict_update(
            original: dict[str, Any],
            update: dict[str, Any],
828
829
830
831
832
        ) -> set[str]:
            """Recursively updates a dictionary with another dictionary.
            Returns a set of duplicate keys that were overwritten.
            """
            duplicates = set[str]()
833
834
            for k, v in update.items():
                if isinstance(v, dict) and isinstance(original.get(k), dict):
835
836
837
838
                    nested_duplicates = recursive_dict_update(original[k], v)
                    duplicates |= {f"{k}.{d}" for d in nested_duplicates}
                elif isinstance(v, list) and isinstance(original.get(k), list):
                    original[k] += v
839
                else:
840
841
                    if k in original:
                        duplicates.add(k)
842
                    original[k] = v
843
            return duplicates
844

845
846
        delete = set[int]()
        dict_args = defaultdict[str, dict[str, Any]](dict)
847
        duplicates = set[str]()
848
        for i, processed_arg in enumerate(processed_args):
849
850
851
852
            if i in delete:  # skip if value from previous arg
                continue

            if processed_arg.startswith("-") and "." in processed_arg:
853
                if "=" in processed_arg:
854
                    processed_arg, value_str = processed_arg.split("=", 1)
855
                    if "." not in processed_arg:
856
                        # False positive, '.' was only in the value
857
858
                        continue
                else:
859
                    value_str = processed_args[i + 1]
860
                    delete.add(i + 1)
861

862
863
864
865
                if processed_arg.endswith("+"):
                    processed_arg = processed_arg[:-1]
                    value_str = json.dumps(list(value_str.split(",")))

866
                key, *keys = processed_arg.split(".")
867
868
869
870
871
                try:
                    value = json.loads(value_str)
                except json.decoder.JSONDecodeError:
                    value = value_str

872
873
                # Merge all values with the same key into a single dict
                arg_dict = create_nested_dict(keys, value)
874
875
                arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
                duplicates |= {f"{key}.{d}" for d in arg_duplicates}
876
877
                delete.add(i)
        # Filter out the dict args we set to None
878
        processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
879
880
881
        if duplicates:
            logger.warning("Found duplicate keys %s", ", ".join(duplicates))

882
883
884
885
886
        # Add the dict args back as if they were originally passed as JSON
        for dict_arg, dict_value in dict_args.items():
            processed_args.append(dict_arg)
            processed_args.append(json.dumps(dict_value))

887
        return super().parse_args(processed_args, namespace)
888

889
890
891
892
    def check_port(self, value):
        """Validate and convert a port number.

        Raises ``ArgumentTypeError`` on failure, so it is suitable as an
        argparse ``type=`` converter.

        Args:
            value: Anything accepted by ``int()``.

        Returns:
            The port as an ``int`` in the range 1024-65535.

        Raises:
            ArgumentTypeError: if ``int(value)`` raises ``ValueError`` or the
                result lies outside 1024-65535.
        """
        try:
            port = int(value)
        except ValueError:
            raise ArgumentTypeError("Port must be an integer") from None

        lowest, highest = 1024, 65535
        if port < lowest or port > highest:
            raise ArgumentTypeError("Port must be between 1024 and 65535")
        return port

901
    def _pull_args_from_config(self, args: list[str]) -> list[str]:
        """Splice the arguments from a ``--config`` YAML file into ``args``.

        The ``--config <file>`` pair is removed from ``args``. The flattened
        arguments returned by ``self.load_config_file(<file>)`` are inserted
        after the sub command, and for ``serve`` after a positional model tag.
        They are placed before the remaining command-line arguments. When the
        result is later parsed by ``super().parse_args``, this ordering
        preserves the precedence cli > config > defaults.

        Example:
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        ```python
        $: vllm {serve,chat,complete} "facebook/opt-12B" --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--config', 'config.yaml',
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            '-tp', '2'
        ]
        ```

        Args:
            args: The raw command-line arguments. They must contain
                ``--config``.

        Returns:
            A new argument list with the config-file arguments inserted.

        Raises:
            ValueError: if ``--config`` appears more than once, if it is the
                last argument (no file path follows), or, for ``serve``, if no
                model is given either positionally or via ``--model`` in the
                config file.
        """
        # Explicit check rather than `assert`, so it still runs under `-O`.
        if args.count("--config") > 1:
            raise ValueError("More than one config file specified!")

        index = args.index("--config")
        if index == len(args) - 1:
            raise ValueError(
                "No config file specified! Please check your command-line arguments."
            )

        file_path = args[index + 1]
        config_args = self.load_config_file(file_path)

        # The CLI arguments that follow the `--config <file>` pair.
        after_config = args[index + 2 :]

        # args[0] may be a sub command {serve,chat,complete,...}. For `serve`
        # it may be followed by a positional model tag. The config args go
        # after those and before the rest of the CLI args. This maintains the
        # precedence cli > config > defaults.
        if args[0].startswith("-"):
            # No sub command (e.g., api_server entry point)
            return config_args + args[:index] + after_config

        if args[0] == "serve":
            model_in_cli = len(args) > 1 and not args[1].startswith("-")
            model_in_config = "--model" in config_args

            if not model_in_cli and not model_in_config:
                raise ValueError(
                    "No model specified! Please specify model either "
                    "as a positional argument or in a config file."
                )

            if model_in_cli:
                # Keep the positional model tag right after the sub command.
                return args[:2] + config_args + args[2:index] + after_config

        return [args[0]] + config_args + args[1:index] + after_config

984
    def load_config_file(self, file_path: str) -> list[str]:
        """Load a YAML config file and flatten it into argparse-style args.

        Each top-level ``key: value`` pair is converted as follows:

        - A boolean for an option that is not a ``StoreBoolean`` action becomes
          the bare flag ``--key`` when true and is dropped when false.
        - A non-empty list becomes ``--key item1 item2 ...``. An empty list is
          dropped.
        - Anything else becomes ``--key str(value)``.

        Example:
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        returns ``['--port', '12323', '--tensor-parallel-size', '4']``.

        Args:
            file_path: Path to a file whose extension is ``yaml`` or ``yml``.

        Returns:
            The flattened argument list. It is empty for an empty document.

        Raises:
            ValueError: if the extension is not ``yaml``/``yml``, or if the
                top level of the document is not a mapping.
            Exception: any error raised while opening or parsing the file.
                It is logged and then re-raised.
        """
        extension: str = file_path.split(".")[-1]
        if extension not in ("yaml", "yml"):
            raise ValueError(
                f"Config file must be of a yaml/yml type. {extension} supplied"
            )

        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception:
            logger.error(
                "Unable to read the config file at %s. Make sure path is correct",
                file_path,
            )
            raise

        # An empty document loads as None: no arguments to add.
        if config is None:
            return []
        # Only a flat mapping of atomic values (and lists) is expected.
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {file_path} must contain a mapping of argument "
                f"names to values, got {type(config).__name__}"
            )

        store_boolean_dests = {
            action.dest for action in self._actions if isinstance(action, StoreBoolean)
        }

        processed_args: list[str] = []
        for key, value in config.items():
            # argparse derives `dest` by replacing "-" with "_". YAML keys may
            # use hyphens (as in the docstring example), so also compare the
            # normalized name.
            is_store_boolean = (
                key in store_boolean_dests
                or str(key).replace("-", "_") in store_boolean_dests
            )
            if isinstance(value, bool) and not is_store_boolean:
                if value:
                    processed_args.append("--" + key)
            elif isinstance(value, list):
                if value:
                    processed_args.append("--" + key)
                    processed_args.extend(str(item) for item in value)
            else:
                processed_args.append("--" + key)
                processed_args.append(str(value))

        return processed_args

1039

1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
class AtomicCounter:
    """A thread-safe counter.

    ``inc`` and ``dec`` perform their read-modify-write while holding a
    ``threading.Lock`` and return the value they produced.
    """

    def __init__(self, initial=0):
        """Create a counter whose starting value is ``initial``."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add ``num`` under the lock and return the resulting value."""
        with self._lock:
            self._value += num
            updated = self._value
        return updated

    def dec(self, num=1):
        """Subtract ``num`` under the lock and return the resulting value."""
        with self._lock:
            self._value -= num
            updated = self._value
        return updated

    @property
    def value(self):
        """The current value (read without taking the lock)."""
        return self._value
1063
1064


1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL,
    then kills the process itself.

    If the process no longer exists, or exits before its children can be
    listed, this returns without doing anything. Descendants that have
    already exited by the time they are signalled are skipped.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
        # The parent may exit between the lookup above and this call, and
        # psutil then raises NoSuchProcess here as well.
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)
1088
1089


1090
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
1091
def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE limit to at least ``target_soft_limit``.

    The soft limit is changed only when it is finite and currently lower
    than ``target_soft_limit``. The hard limit is left unchanged, and nothing
    happens on Windows. If the new limit is rejected (e.g. it exceeds the
    hard limit), a warning is logged instead of raising.

    Args:
        target_soft_limit: The desired minimum soft limit.
    """
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit is reported as RLIM_INFINITY. That value can be
    # negative, so it would compare as "lower" than any target. Never shrink it.
    if current_soft == resource.RLIM_INFINITY or current_soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except (ValueError, OSError) as e:
        # ValueError: e.g. the target exceeds the hard limit.
        # OSError: the underlying setrlimit() system call failed.
        logger.warning(
            "Found ulimit of %s and failed to automatically increase "
            "with error %s. This can cause fd limit errors like "
            "`OSError: [Errno 24] Too many open files`. Consider "
            "increasing with ulimit -n",
            current_soft,
            e,
        )
1113
1114
1115
1116
1117
1118
1119
1120
1121


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    """Return the formatted traceback of the exception currently being handled.

    Meant to be called from inside an ``except`` block. It formats
    ``sys.exc_info()`` with ``traceback.format_exception`` and joins the
    resulting lines into one string.
    """
    return "".join(traceback.format_exception(*sys.exc_info()))


1122
def split_zmq_path(path: str) -> tuple[str, str, str]:
    """Split a ZMQ endpoint URL into ``(scheme, host, port)``.

    ``host`` and ``port`` are strings and are empty when absent. A ``tcp``
    endpoint must have both a host and a port. Any other scheme must not
    carry a port.

    Raises:
        ValueError: if ``path`` has no scheme or breaks the rules above.
    """
    parts = urlparse(path)
    scheme = parts.scheme
    if not scheme:
        raise ValueError(f"Invalid zmq path: {path}")

    host = parts.hostname or ""
    port = str(parts.port or "")

    is_tcp = scheme == "tcp"
    if is_tcp and not (host and port):
        # tcp needs both pieces to form an address
        raise ValueError(f"Invalid zmq path: {path}")
    if not is_tcp and port:
        # a port is meaningless for non-tcp transports
        raise ValueError(f"Invalid zmq path: {path}")

    return scheme, host, port


1143
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
    """Make a ZMQ path from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - can be an IPv4 address, IPv6 address, or hostname.
        port: Optional port number, only used for TCP sockets.

    Returns:
        A properly formatted ZMQ path string.
    """
1154
    if port is None:
1155
1156
1157
1158
1159
1160
        return f"{scheme}://{host}"
    if is_valid_ipv6_address(host):
        return f"{scheme}://[{host}]:{port}"
    return f"{scheme}://{host}:{port}"


1161
1162
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: zmq.asyncio.Context | zmq.Context,  # type: ignore[name-defined]
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    identity: bytes | None = None,
    linger: int | None = None,
) -> zmq.Socket | zmq.asyncio.Socket:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    Args:
        ctx: The sync or asyncio ZMQ context to create the socket on.
        path: The endpoint, e.g. ``tcp://host:port`` or ``ipc://...``.
        socket_type: A ZMQ socket type constant such as ``zmq.PUSH``.
        bind: ``bind`` (True) or ``connect`` (False) to ``path``. When None,
            every socket type except PUSH, SUB and XSUB binds.
        identity: Optional value for ``zmq.IDENTITY``.
        linger: Optional value for ``zmq.LINGER``.

    Returns:
        The configured socket.

    Raises:
        Whatever socket configuration, path parsing, bind or connect raises.
        The socket is closed before the exception propagates, so it is not
        leaked.
    """
    mem = psutil.virtual_memory()

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1

    if bind is None:
        bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

    # Named `sock` to avoid shadowing the module-level `socket` import.
    sock = ctx.socket(socket_type)
    try:
        if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
            sock.setsockopt(zmq.RCVHWM, 0)
            sock.setsockopt(zmq.RCVBUF, buf_size)

        if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
            sock.setsockopt(zmq.SNDHWM, 0)
            sock.setsockopt(zmq.SNDBUF, buf_size)

        if identity is not None:
            sock.setsockopt(zmq.IDENTITY, identity)

        if linger is not None:
            sock.setsockopt(zmq.LINGER, linger)

        if socket_type == zmq.XPUB:
            sock.setsockopt(zmq.XPUB_VERBOSE, True)

        # Determine if the path is a TCP socket with an IPv6 address.
        # Enable IPv6 on the zmq socket if so.
        scheme, host, _ = split_zmq_path(path)
        if scheme == "tcp" and is_valid_ipv6_address(host):
            sock.setsockopt(zmq.IPV6, 1)

        if bind:
            sock.bind(path)
        else:
            sock.connect(path)
    except BaseException:
        # Close the half-configured socket so it is not left open on `ctx`.
        sock.close(linger=0)
        raise

    return sock


@contextlib.contextmanager
def zmq_socket_ctx(
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    linger: int = 0,
    identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
    """Context manager yielding a ZMQ socket on its own private context.

    A new ``zmq.Context`` is created for the socket and destroyed with the
    given ``linger`` on exit. A ``KeyboardInterrupt`` raised while the
    context manager is active is logged at debug level and not propagated.
    """
    context = zmq.Context()  # type: ignore[attr-defined]
    try:
        sock = make_zmq_socket(
            context, path, socket_type, bind=bind, identity=identity
        )
        yield sock
    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")
    finally:
        context.destroy(linger=linger)
1236
1237


1238
1239
1240
1241
1242
1243
1244
def _maybe_force_spawn():
    """Override VLLM_WORKER_MULTIPROC_METHOD to ``spawn`` when required.

    The override happens when running inside a Ray actor, or when CUDA or
    XPU is already initialized in this process. A warning with the reasons
    is logged when the override happens. The function returns immediately
    if the variable is already ``spawn``.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons: list[str] = []

    if is_in_ray_actor():
        # Spawned subprocesses still need to connect to the Ray cluster.
        # Environment variables are inherited even with `spawn`, so export
        # the address here.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    accelerator_reason = None
    if cuda_is_initialized():
        accelerator_reason = "CUDA is initialized"
    elif xpu_is_initialized():
        accelerator_reason = "XPU is initialized"
    if accelerator_reason is not None:
        reasons.append(accelerator_reason)

    if not reasons:
        return

    logger.warning(
        "We must use the `spawn` multiprocessing start method. "
        "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
        "See https://docs.vllm.ai/en/latest/usage/"
        "troubleshooting.html#python-multiprocessing "
        "for more information. Reasons: %s",
        "; ".join(reasons),
    )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Return the multiprocessing context used for worker processes.

    The start method comes from ``envs.VLLM_WORKER_MULTIPROC_METHOD``
    (default fork). Before it is read, ``_maybe_force_spawn`` is called and
    may override the environment variable to ``spawn``.
    """
    _maybe_force_spawn()
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
1282
1283
1284


def bind_kv_cache(
    ctx: dict[str, Any],
    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
    shared_kv_cache_layers: dict[str, str] | None = None,
) -> None:
    """Bind kv cache tensors to the attention modules in ``ctx``.

    For every layer that needs its own kv cache, this sets
    ``ctx[layer_name].kv_cache[ve] = kv_cache[ve][pos]``. Here ``pos`` is the
    position of the layer's index among the sorted distinct layer indices
    that need a kv cache.

    Special things handled here:
    1. Some models have non-attention layers, e.g., Jamba.
    2. Pipeline parallelism: each rank only has a subset of layers.
    3. Encoder attention has no kv cache.
    4. Encoder-decoder models: encoder-decoder attention and decoder-only
       attention of the same layer (e.g., bart's decoder.layers.1.self_attn
       and decoder.layers.1.encoder_attn) are mapped to the same kv cache
       tensor.
    5. Some models have attention layers that share kv cache with previous
       layers. This is specified through ``shared_kv_cache_layers``.

    Args:
        ctx: Maps layer name to a module. Modules with an ``attn_type`` of
            DECODER or ENCODER_DECODER and no
            ``kv_sharing_target_layer_name`` receive a kv cache.
        kv_cache: Tensors indexed ``[virtual_engine][layer position]``.
        shared_kv_cache_layers: Maps a layer name to the earlier layer whose
            ``kv_cache`` it reuses.
    """
    if shared_kv_cache_layers is None:
        shared_kv_cache_layers = {}

    from vllm.attention import AttentionType
    from vllm.model_executor.models.utils import extract_layer_index

    layer_need_kv_cache = [
        layer_name
        for layer_name in ctx
        if (
            hasattr(ctx[layer_name], "attn_type")
            and ctx[layer_name].attn_type
            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
        )
        and ctx[layer_name].kv_sharing_target_layer_name is None
    ]

    # Map each distinct layer index to its position in `kv_cache`. Layers
    # sharing an index (case 4 above) get the same position. A dict lookup
    # replaces the previous O(n) `list.index` call per layer.
    sorted_layer_indices = sorted(
        {extract_layer_index(layer_name) for layer_name in layer_need_kv_cache}
    )
    kv_cache_pos = {
        layer_idx: pos for pos, layer_idx in enumerate(sorted_layer_indices)
    }

    for layer_name in layer_need_kv_cache:
        kv_cache_idx = kv_cache_pos[extract_layer_index(layer_name)]
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]

    # `shared_kv_cache_layers` is never None here (defaulted above).
    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        assert extract_layer_index(target_layer_name) < extract_layer_index(
            layer_name
        ), "v0 doesn't support interleaving kv sharing"
        ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
1331
1332


1333
1334
def run_method(
    obj: Any,
1335
    method: str | bytes | Callable,
1336
1337
1338
    args: tuple[Any],
    kwargs: dict[str, Any],
) -> Any:
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
1352
1353
1354
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
1355
1356
1357
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378


def import_pynvml():
    """
    Return vLLM's vendored copy of the official ``pynvml`` module.

    Background: libnvml.so is the library behind nvidia-smi, and pynvml is a
    Python wrapper around it. vLLM uses it to query GPU status without
    initializing a CUDA context in the current process.

    Two PyPI packages have historically provided a module named ``pynvml``:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): the official
        wrapper and a vLLM dependency, shipped as a standalone ``pynvml``
        Python file.
    - `pynvml` (https://pypi.org/project/pynvml/): an unofficial wrapper.
        Before version 12.0 it also provided a ``pynvml`` module. Being a
        package, that module took priority over the official standalone file,
        which caused errors when both were installed. From 12.0 onwards it
        moved to ``pynvml_utils`` to avoid the conflict.

    Many community images still ship the unofficial one by mistake. For
    example, ``nvcr.io/nvidia/pytorch:24.12-py3`` does, see
    https://github.com/vllm-project/vllm/issues/12847. To sidestep the
    conflict entirely, the official module is copied into vLLM and imported
    from there.
    """
    return importlib.import_module("vllm.third_party.pynvml")
1390
1391


1392
def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
    """
    A replacement for `abc.ABC`.
    When we use `abc.ABC`, subclasses will fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class, and log a debug message after `__init__` if a public
    method's source still contains `NotImplementedError`.

    The class is modified in place (its `__init__` is wrapped) and returned.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        unimplemented_methods = []
        for attr_name in dir(self):
            # bypass inner method
            if attr_name.startswith("_"):
                continue

            try:
                attr = getattr(self, attr_name)
            except AttributeError:
                continue

            # Only bound methods are inspected. Non-callable attributes and
            # callables without `__func__` (plain functions, classes,
            # builtins) are skipped.
            if not callable(attr):
                continue
            attr_func = getattr(attr, "__func__", None)
            if attr_func is None:
                continue

            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source unavailable (e.g. C-implemented or dynamically
                # generated functions); nothing to check.
                continue
            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)

        if unimplemented_methods:
            method_names = ",".join(unimplemented_methods)
            logger.debug("Methods %s not implemented in %s", method_names, self)

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    type.__setattr__(cls, "__init__", wrapped_init)
    return cls
1433
1434


1435
## moved to vllm.utils.profiling (imported at module top)
1436
1437


1438
1439
# Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
    """Return whether the model uses ALiBi.

    The indicators are checked in this order:
    - Falcon: ``alibi`` on the text config;
    - Bloom: ``BloomForCausalLM`` in the architectures;
    - codellm_1b_alibi: ``position_encoding_type == "alibi"``;
    - MPT: ``attn_config.alibi``, where ``attn_config`` is a dict or an object.

    For the Falcon and object-style MPT cases the attribute value itself is
    returned, so the result may be a truthy/falsy non-``bool``.
    """
    cfg = model_config.hf_text_config

    # Falcon
    falcon_alibi = getattr(cfg, "alibi", False)
    if falcon_alibi:
        return falcon_alibi

    # Bloom
    architectures = getattr(model_config.hf_config, "architectures", [])
    if "BloomForCausalLM" in architectures:
        return True

    # codellm_1b_alibi
    if getattr(cfg, "position_encoding_type", "") == "alibi":
        return True

    # MPT
    if not hasattr(cfg, "attn_config"):
        return False
    attn_config = cfg.attn_config
    if isinstance(attn_config, dict):
        return attn_config.get("alibi", False) or False
    return getattr(attn_config, "alibi", False)
1461
1462


1463
## moved to vllm.utils.hashing
1464
1465


1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
@cache
def _has_module(module_name: str) -> bool:
    """Return True if *module_name* can be found in the current environment.

    Dotted names are supported. Note that ``find_spec`` imports the parent
    packages of a dotted name. If a parent package is missing, the result is
    False instead of a ``ModuleNotFoundError``.

    The result is cached so that subsequent queries for the same module incur
    no additional overhead.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # has been loaded, so import it explicitly.
    import importlib.util

    try:
        return importlib.util.find_spec(module_name) is not None
    except ImportError:
        # Raised (as ModuleNotFoundError) when a parent package is missing.
        return False


def has_pplx() -> bool:
    """Return True when the optional ``pplx_kernels`` package can be found."""
    module_name = "pplx_kernels"
    return _has_module(module_name)


def has_deep_ep() -> bool:
    """Return True when the optional ``deep_ep`` package can be found."""
    module_name = "deep_ep"
    return _has_module(module_name)


def has_deep_gemm() -> bool:
    """Return True when the optional ``deep_gemm`` package can be found."""
    module_name = "deep_gemm"
    return _has_module(module_name)
1492
1493


1494
1495
1496
1497
1498
1499
def has_triton_kernels() -> bool:
    """Return True when the optional ``triton_kernels`` package can be found."""
    module_name = "triton_kernels"
    return _has_module(module_name)


1500
1501
1502
1503
1504
1505
def has_tilelang() -> bool:
    """Return True when the optional ``tilelang`` package can be found."""
    module_name = "tilelang"
    return _has_module(module_name)


1506
1507
1508
def set_process_title(
    name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
) -> None:
    """
    Set the title of the current process via ``setproctitle``.

    The resulting title is ``"<prefix>::<name>"``, or
    ``"<prefix>::<name>_<suffix>"`` when ``suffix`` is non-empty.

    Args:
        name: The base title to assign to the current process.
        suffix: Optional suffix appended to ``name`` with an underscore.
        prefix: Prepended to the title, separated by ``::``. The default is
            read from ``envs.VLLM_PROCESS_NAME_PREFIX`` once, when this
            function is defined.
    """
    base_name = f"{name}_{suffix}" if suffix else name
    setproctitle.setproctitle(f"{prefix}::{base_name}")
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534


def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Prepend a process-specific prefix to each line written to ``file``.

    ``file.write`` is replaced with a wrapper that writes
    ``f"{CYAN}({worker_name} pid={pid}){RESET} "`` before the first character
    of every line. A ``start_new_line`` attribute stored on ``file``
    remembers, across calls, whether the next character starts a new line.
    """
    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write

    def write_with_prefix(s: str) -> int:
        # Like TextIO.write, return the number of characters of `s` written.
        # The injected prefixes are not counted.
        if not s:
            return 0
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        while (next_idx := s.find("\n", idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                # `s` ended with a newline. Defer the next prefix until
                # something is actually written on the new line.
                file.start_new_line = True  # type: ignore[attr-defined]
                return len(s)
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]
        return len(s)

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]


1550
def decorate_logs(process_name: str | None = None) -> None:
    """
    Prefix every line subsequently written to stdout and stderr with the
    process name and PID.

    Call this early in a process (e.g. before initializing the api_server,
    engine_core or worker classes). That way, interleaved output from
    several processes can be told apart.

    Args:
        process_name: Name to show in the prefix. Defaults to the current
            process name reported by the multiprocessing context.
    """
    name = (
        process_name
        if process_name is not None
        else get_mp_context().current_process().name
    )
    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, name, pid)
1570
1571
1572


def length_from_prompt_token_ids_or_embeds(
1573
1574
    prompt_token_ids: list[int] | None,
    prompt_embeds: torch.Tensor | None,
1575
) -> int:
1576
    """Calculate the request length (in number of tokens) give either
1577
1578
    prompt_token_ids or prompt_embeds.
    """
1579
1580
    prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
    prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
1581
1582
1583

    if prompt_token_len is None:
        if prompt_embeds_len is None:
1584
            raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
1585
1586
        return prompt_embeds_len
    else:
1587
        if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
1588
1589
1590
            raise ValueError(
                "Prompt token ids and prompt embeds had different lengths"
                f" prompt_token_ids={prompt_token_len}"
1591
1592
                f" prompt_embeds={prompt_embeds_len}"
            )
1593
        return prompt_token_len
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606


@contextlib.contextmanager
def set_env_var(key, value):
    """Temporarily set environment variable ``key`` to ``value``.

    On exit, the previous value is restored. If the variable was not set
    before, it is removed. Removal tolerates the body having already
    deleted the variable.
    """
    old = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            # pop() instead of del: the body may have removed the variable,
            # and a KeyError here would mask any exception raised by the body.
            os.environ.pop(key, None)
        else:
            os.environ[key] = old
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626


def unique_filepath(fn: Callable[[int], Path]) -> Path:
    """
    Return the first of ``fn(0)``, ``fn(1)``, ``fn(2)``, ... that does not
    exist yet.

    ``fn`` maps an integer to a candidate path that embeds that integer at a
    fixed location.

    Note: This function has a TOCTOU race condition. The returned path may
    be created by someone else before the caller uses it. Callers should
    create the file atomically (e.g., open with 'x' mode).
    """
    index = 0
    candidate = fn(index)
    while candidate.exists():
        index += 1
        candidate = fn(index)
    return candidate