__init__.py 56.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import enum
7
import getpass
8
import hashlib
9
import importlib
10
import inspect
11
import ipaddress
12
import json
13
import multiprocessing
14
import os
15
import pickle
16
import signal
17
import socket
18
import subprocess
19
import sys
20
import tempfile
21
import textwrap
22
import threading
23
import traceback
Zhuohan Li's avatar
Zhuohan Li committed
24
import uuid
25
import warnings
26
import weakref
27
28
29
30
31
32
33
34
from argparse import (
    Action,
    ArgumentDefaultsHelpFormatter,
    ArgumentParser,
    ArgumentTypeError,
    RawDescriptionHelpFormatter,
    _ArgumentGroup,
)
35
from collections import defaultdict
36
from collections.abc import (
37
    Callable,
38
39
40
    Iterator,
    Sequence,
)
41
from concurrent.futures.process import ProcessPoolExecutor
42
from functools import cache, partial, wraps
43
from pathlib import Path
44
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
45
from urllib.parse import urlparse
46
from uuid import uuid4
Zhuohan Li's avatar
Zhuohan Li committed
47

48
import cbor2
49
import cloudpickle
50
import psutil
51
import regex as re
52
import setproctitle
Zhuohan Li's avatar
Zhuohan Li committed
53
import torch
54
import yaml
55
56
import zmq
import zmq.asyncio
57

58
import vllm.envs as envs
59
from vllm.logger import enable_trace_function_call, init_logger
60
from vllm.ray.lazy_utils import is_in_ray_actor
61

62
if TYPE_CHECKING:
63
64
    from argparse import Namespace

65
    from vllm.config import ModelConfig, VllmConfig
66
67
68
69
70
else:
    Namespace = object

    ModelConfig = object
    VllmConfig = object
71

72
73
# Module-level logger, obtained through vLLM's own logging setup.
logger = init_logger(__name__)

74
75
76
77
78
79
# Default per-step token budgets (presumably the default for
# `max_num_batched_tokens` — confirm at the call sites).
#
# This value is chosen to have a balance between inter-token latency (ITL)
# and time-to-first-token (TTFT). Note it is not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
# Larger budget used for pooling models (per the name).
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
# Budget used for multimodal models (per the name).
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Constants related to forcing the attention backend selection

# Name of the environment variable that may be set in order to force the
# attention backend, instead of letting the Attention wrapper auto-select it.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Recognized string values of the STR_BACKEND_ENV_VAR environment variable,
# one per attention backend.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
# Presumably a sentinel for an unrecognized backend value — confirm at use.
STR_INVALID_VAL: str = "INVALID"

95

96
# ANSI color codes
97
98
# Bold cyan foreground, and the sequence that restores default attributes.
CYAN = "\033[1;36m"
RESET = "\033[0;0m"
99

100

101
# Generic type variables for annotating helpers in this module.
T = TypeVar("T")
102
U = TypeVar("U")
103

104

Woosuk Kwon's avatar
Woosuk Kwon committed
105
106
107
108
109
class Device(enum.Enum):
    """Coarse device kind: GPU or CPU."""

    # Spelled out explicitly; these are exactly the values enum.auto() yields.
    GPU = 1
    CPU = 2


110
111
112
113
114
class LayerBlockType(enum.Enum):
    """Kind of layer block; each member's value is the same as its name."""

    attention = "attention"
    mamba = "mamba"


Woosuk Kwon's avatar
Woosuk Kwon committed
115
116
117
118
class Counter:
    """Simple integer counter.

    ``next(counter)`` returns the current value and then advances it by one,
    so ``Counter(5)`` yields 5, 6, 7, ...  The counter is its own iterator,
    so it also works in ``for`` loops and with ``itertools`` helpers.
    """

    def __init__(self, start: int = 0) -> None:
        # Remembered so that reset() restores the configured starting value
        # rather than always jumping back to 0.
        self._start = start
        self.counter = start

    def __iter__(self) -> "Counter":
        return self

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        """Rewind the counter to the value it was created with."""
        self.counter = self._start
Zhuohan Li's avatar
Zhuohan Li committed
126

127

Zhuohan Li's avatar
Zhuohan Li committed
128
129
def random_uuid() -> str:
    """Return a fresh random UUID4 as a 32-character lowercase hex string."""
    # UUID.hex is already a str, so no extra conversion is needed.
    return uuid.uuid4().hex
130

131

132
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
    """Close every socket in *sockets* with ``linger=0``, skipping ``None``.

    ``linger=0`` makes ZeroMQ drop pending outbound messages instead of
    waiting for them to be sent while closing.
    """
    live_sockets = (s for s in sockets if s is not None)
    for s in live_sockets:
        s.close(linger=0)


138
def _probe_local_ip(
    family: socket.AddressFamily, probe_addr: tuple[str, int]
) -> str | None:
    """Return the local address the OS would use to reach *probe_addr*.

    "Connecting" a UDP socket sends no packets; it only makes the kernel pick
    a route, whose source address ``getsockname()`` then reports. Returns
    ``None`` if the address family is unavailable or no route exists.
    """
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as s:
            s.connect(probe_addr)  # Doesn't need to be reachable
            return s.getsockname()[0]
    except Exception:
        # Best effort: any failure just means "try the next strategy".
        return None


def get_ip() -> str:
    """Return an IP address other vLLM processes can use to reach this host.

    Resolution order:

    1. ``VLLM_HOST_IP`` (read through ``envs``), if set.
    2. The local address routed towards a public IPv4 host.
    3. The same probe over IPv6.
    4. ``"0.0.0.0"``, with a warning.

    ``HOST_IP`` is ignored; a deprecation warning is logged if it is set while
    ``VLLM_HOST_IP`` is not.
    """
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
        logger.warning(
            "The environment variable HOST_IP is deprecated and ignored, as"
            " it is often used by Docker and other software to"
            " interact with the container's network stack. Please "
            "use VLLM_HOST_IP instead to set the IP address for vLLM processes"
            " to communicate with each other."
        )
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    # The probe sockets are closed by the helper on every path.
    probes = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, probe_addr in probes:
        ip = _probe_local_ip(family, probe_addr)
        if ip is not None:
            return ip

    # HOST_IP is intentionally not suggested here: it is ignored above.
    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable VLLM_HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"
178
179


180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def test_loopback_bind(address, family):
    """Return True if a UDP socket of *family* can bind to *address*.

    Port 0 lets the OS pick any free port, so only the address is checked.
    The socket is closed on every path, including when ``bind`` fails.
    """
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as s:
            s.bind((address, 0))  # Port 0 = auto assign
    except OSError:
        return False
    return True


# The ``test_`` prefix makes pytest collect this helper as a test (and fail on
# its "fixtures" address/family) whenever it is imported into a test module.
test_loopback_bind.__test__ = False


def get_loopback_ip() -> str:
    """Return the loopback address for intra-host communication.

    ``VLLM_LOOPBACK_IP`` (read through ``envs``) wins when set. Otherwise the
    first of ``127.0.0.1`` / ``::1`` that a socket can bind to is used.

    Raises:
        RuntimeError: if neither loopback address can be bound.
    """
    configured = envs.VLLM_LOOPBACK_IP
    if configured:
        return configured

    # VLLM_LOOPBACK_IP is not set: probe the standard loopback addresses.
    candidates = (("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6))
    for candidate, family in candidates:
        if test_loopback_bind(candidate, family):
            return candidate

    raise RuntimeError(
        "Neither 127.0.0.1 nor ::1 are bound to a local interface. "
        "Set the VLLM_LOOPBACK_IP environment variable explicitly."
    )
206
207


208
209
210
211
212
213
214
215
def is_valid_ipv6_address(address: str) -> bool:
    """Return True if *address* parses as an IPv6 address."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True


216
def split_host_port(host_port: str) -> tuple[str, int]:
    """Split ``host:port`` or ``[ipv6]:port`` into ``(host, port)``.

    The brackets are stripped from IPv6 literals. Malformed input raises
    whatever the underlying string splitting / ``int`` conversion raises.
    """
    if not host_port.startswith("["):
        host, port = host_port.split(":")
        return host, int(port)

    # Bracketed IPv6 literal: "[<address>]:<port>"
    bracketed, tail = host_port.rsplit("]", 1)
    return bracketed[1:], int(tail.split(":")[1])


def join_host_port(host: str, port: int) -> str:
    """Format *host* and *port* as ``host:port``.

    IPv6 literals are bracketed (``[::1]:8000``) so the port separator
    stays unambiguous.
    """
    template = "[{}]:{}" if is_valid_ipv6_address(host) else "{}:{}"
    return template.format(host, port)


235
def get_distributed_init_method(ip: str, port: int) -> str:
    """Return the ``tcp://`` init-method URL for *ip* and *port*.

    Delegates to get_tcp_uri.
    """
    return get_tcp_uri(ip=ip, port=port)


def get_tcp_uri(ip: str, port: int) -> str:
    """Return ``tcp://<ip>:<port>``, bracketing *ip* when it is IPv6."""
    return "tcp://" + join_host_port(ip, port)
244
245


246
247
248
249
250
def get_open_zmq_ipc_path() -> str:
    """Return a new ``ipc://`` endpoint inside ``VLLM_RPC_BASE_PATH``.

    The file name is a random UUID4, so concurrent callers do not collide.
    """
    base_dir = envs.VLLM_RPC_BASE_PATH
    socket_name = uuid4()
    return "ipc://{}/{}".format(base_dir, socket_name)


251
252
253
254
def get_open_zmq_inproc_path() -> str:
    """Return a unique ``inproc://`` endpoint name (random UUID4)."""
    return "inproc://" + str(uuid4())


255
def get_open_port() -> int:
    """
    Return an open port for the vLLM process to listen on.

    When running data parallel (``VLLM_DP_MASTER_PORT`` is set), the 10 ports
    starting at the DP master port are reserved for the DP master process
    (which currently uses 2 of them) and are never returned.
    """
    if "VLLM_DP_MASTER_PORT" not in os.environ:
        return _get_open_port()

    master_port = envs.VLLM_DP_MASTER_PORT
    reserved = range(master_port, master_port + 10)
    # NOTE(review): with VLLM_PORT set, _get_open_port returns the first free
    # port at or above VLLM_PORT on every call; if that port lies inside
    # `reserved`, this loop never terminates.
    port = _get_open_port()
    while port in reserved:
        port = _get_open_port()
    return port

youkaichao's avatar
youkaichao committed
273

274
275
def get_open_ports_list(count: int = 5) -> list[int]:
    """Return *count* distinct open ports, in no particular order.

    Ports are only known to be free at probe time; nothing reserves them.
    """
    # NOTE(review): if get_open_port keeps returning the same port (e.g. a
    # free VLLM_PORT), the set never grows and this loops forever for
    # count > 1.
    found: set[int] = set()
    while len(found) < count:
        found.add(get_open_port())
    return list(found)


282
def _get_open_port() -> int:
    """Return a TCP port number that could be bound at the time of the call.

    With ``VLLM_PORT`` set, ports are tried upward from it (IPv4) until one
    binds. Otherwise the OS assigns an ephemeral port, trying IPv4 first
    and falling back to IPv6. The probe socket is closed before returning.
    """

    def bind_ephemeral(family: int) -> int:
        # Port 0 asks the OS for any free port; report which one it chose.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    candidate = envs.VLLM_PORT
    if candidate is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.bind(("", candidate))
                    return candidate
            except OSError:
                candidate += 1  # Increment port number if already in use
                logger.info(
                    "Port %d is already in use, trying port %d",
                    candidate - 1,
                    candidate,
                )

    # try ipv4, fall back to ipv6
    try:
        return bind_ephemeral(socket.AF_INET)
    except OSError:
        return bind_ephemeral(socket.AF_INET6)
303
304


305
def find_process_using_port(port: int) -> psutil.Process | None:
    """Return a process other than this one with a connection on local *port*.

    Returns ``None`` if no such live process is found, and always returns
    ``None`` on macOS (see the TODO below).
    """
    # TODO: We can not check for running processes with network
    # port on macOS. Therefore, we can not have a full graceful shutdown
    # of vLLM. For now, let's not look for processes in this case.
    # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
    if sys.platform.startswith("darwin"):
        return None

    our_pid = os.getpid()
    for conn in psutil.net_connections():
        if conn.laddr.port != port or conn.pid is None or conn.pid == our_pid:
            continue
        try:
            return psutil.Process(conn.pid)
        except psutil.NoSuchProcess:
            # That process exited since the snapshot was taken; another
            # connection on the same port may still belong to a live process.
            continue
    return None


323
def update_environment_variables(envs: dict[str, str]):
    """Copy every ``name -> value`` pair of *envs* into ``os.environ``.

    A warning is logged for each variable that already exists with a
    different value before it is overwritten. (Inside this function the
    parameter shadows the module-level ``envs`` import.)
    """
    for name, new_value in envs.items():
        old_value = os.environ.get(name)
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                name,
                old_value,
                new_value,
            )
        os.environ[name] = new_value
333
334


335
336
337
338
339
def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer >= a / b."""
    # ceil(a / b) == -floor(-a / b)
    return -((-a) // b)


340
341
342
343
344
345
346
def next_power_of_2(n) -> int:
    """Smallest power of two >= n (inclusive); 1 for any n < 1."""
    return 1 if n < 1 else 1 << (n - 1).bit_length()


347
348
349
350
351
352
353
def prev_power_of_2(n: int) -> int:
    """Largest power of two <= n (inclusive); 0 for any n <= 0."""
    return 0 if n <= 0 else 1 << (n.bit_length() - 1)


354
355
356
357
def round_up(x: int, y: int) -> int:
    """Round *x* up to a multiple of *y* (the formula assumes y > 0)."""
    num_chunks = (x + y - 1) // y
    return num_chunks * y


358
359
360
361
def round_down(x: int, y: int) -> int:
    """Round *x* down to a multiple of *y* (floor semantics)."""
    num_chunks = x // y
    return num_chunks * y


362
@cache
def is_pin_memory_available() -> bool:
    """Return whether the current platform supports pinned host memory.

    The answer is computed once and cached. The platform module is imported
    lazily, on first call.
    """
    from vllm.platforms import current_platform as platform

    return platform.is_pin_memory_available()
367
368


369
370
371
372
373
374
375
376
@cache
def is_uva_available() -> bool:
    """Return whether Unified Virtual Addressing (UVA) can be used.

    UVA needs pinned memory; that is currently the only requirement checked.
    """
    # TODO: Add more requirements for UVA if needed.
    pinned_memory_ok = is_pin_memory_available()
    return pinned_memory_ok


377
378
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
379
def init_cached_hf_modules() -> None:
    """Lazily initialize the Hugging Face dynamic-module cache.

    Calls ``transformers.dynamic_module_utils.init_hf_modules``; the import
    happens here so ``transformers`` is only loaded when this is called.
    """
    from transformers.dynamic_module_utils import init_hf_modules as _init_modules

    _init_modules()
386
387


388
@cache
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.

    The `ldconfig -p` cache is searched first, then every directory listed
    in `LD_LIBRARY_PATH`. Successful lookups are cached per `lib_name`.

    Raises:
        ValueError: if the library is found in neither place.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    except (OSError, subprocess.CalledProcessError):
        # ldconfig can be missing (minimal containers, non-Linux hosts) or
        # fail; fall through to the LD_LIBRARY_PATH search instead of
        # crashing before it is ever tried.
        libs = ""
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = envs.LD_LIBRARY_PATH
    if not locs and env_ld_library_path:
        candidates = (
            os.path.join(directory, lib_name)
            for directory in env_ld_library_path.split(":")
        )
        locs = [path for path in candidates if os.path.exists(path)]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]


416
def find_nccl_library() -> str:
    """
    Return the NCCL (or RCCL) shared library to load.

    ``VLLM_NCCL_SO_PATH`` wins when set. Otherwise the soname matching
    PyTorch's GPU backend is returned: ``libnccl.so.2`` for CUDA builds,
    ``librccl.so.1`` for ROCm builds. Once ``torch`` is imported, these
    sonames can be found by ``ctypes`` automatically.

    Raises:
        ValueError: if PyTorch was built for neither CUDA nor ROCm.
    """
    so_file = envs.VLLM_NCCL_SO_PATH
    if so_file:
        # manually load the nccl library
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
        )
        return so_file

    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.debug_once("Found nccl from library %s", so_file)
    return so_file
439
440


441
def find_nccl_include_paths() -> list[str] | None:
    """
    Return include directories that hold NCCL headers, or ``None``.

    Candidates, in priority order:

    1. ``VLLM_NCCL_INCLUDE_PATH`` (read through ``envs``), if it is a
       directory.
    2. ``<location>/include`` for each search location of the
       ``nvidia.nccl`` package (brought by nvidia-nccl-cuXX), if that
       directory contains ``nccl.h``.

    Empty and duplicate entries are dropped, keeping the first occurrence.
    These paths supplement ``torch.utils.cpp_extension.include_paths``,
    which ``load_inline`` uses by default.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded. Without this import the lookup below could raise
    # AttributeError, which the best-effort `except` would silently swallow.
    import importlib.util

    paths: list[str] = []
    inc = envs.VLLM_NCCL_INCLUDE_PATH
    if inc and os.path.isdir(inc):
        paths.append(inc)

    try:
        spec = importlib.util.find_spec("nvidia.nccl")
        if spec and getattr(spec, "submodule_search_locations", None):
            for loc in spec.submodule_search_locations:
                inc_dir = os.path.join(loc, "include")
                if os.path.exists(os.path.join(inc_dir, "nccl.h")):
                    paths.append(inc_dir)
    except Exception:
        # Best effort: a missing or broken `nvidia` package just means no
        # wheel-provided headers.
        pass

    # dict.fromkeys de-duplicates while preserving insertion order.
    unique_paths = list(dict.fromkeys(p for p in paths if p))
    return unique_paths or None


472
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
    """Turn on function-call tracing for the calling thread.

    Does nothing unless ``VLLM_TRACE_FUNCTION`` is enabled. The log goes to
    ``<tempdir>/<username>/vllm/vllm-instance-<instance_id>/``. Its file name
    is built from the process id, thread id and current time, with spaces
    replaced by underscores. Missing directories are created.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    # add username to the temp dir to avoid permission issues
    user_tmp_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
    log_name = (
        f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
        f"_thread_{threading.get_ident()}_"
        f"at_{datetime.datetime.now()}.log"
    ).replace(" ", "_")
    instance_dir = f"vllm-instance-{vllm_config.instance_id}"
    log_path = os.path.join(user_tmp_dir, "vllm", instance_dir, log_name)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)
491
492


493
494
495
496
497
498
499
def cuda_is_initialized() -> bool:
    """Return True only if torch has CUDA support and CUDA is initialized."""
    # Short-circuits for torch builds compiled without CUDA.
    return torch.cuda._is_compiled() and torch.cuda.is_initialized()


500
501
502
503
504
505
506
def xpu_is_initialized() -> bool:
    """Check if XPU is initialized."""
    if not torch.xpu._is_compiled():
        return False
    return torch.xpu.is_initialized()


507
508
509
def cuda_get_device_properties(
    device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
    """Return the CUDA device properties *names* of *device*, in order.

    If CUDA is not yet initialized here (and *init_cuda* is false), the query
    runs in a forked single-worker pool so that this process does not
    initialize CUDA as a side effect.
    """
    if not (init_cuda or cuda_is_initialized()):
        # The child re-enters with init_cuda=True and answers directly.
        fork_ctx = multiprocessing.get_context("fork")
        with ProcessPoolExecutor(max_workers=1, mp_context=fork_ctx) as pool:
            future = pool.submit(cuda_get_device_properties, device, names, True)
            return future.result()

    props = torch.cuda.get_device_properties(device)
    return tuple(getattr(props, attr) for attr in names)
520
521


522
523
524
def weak_bind(
    bound_method: Callable[..., Any],
) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected.

    While the instance is alive, the returned callable forwards all arguments
    to the method and discards its return value.
    """
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        # Test for None rather than truthiness: a live instance that defines
        # __bool__ or __len__ may be falsy and must still be called.
        inst = ref()
        if inst is not None:
            unbound(inst, *args, **kwargs)

    return weak_bound


538
class StoreBoolean(Action):
    """argparse action that stores ``true``/``false`` (any case) as a bool.

    Any other value raises ``ValueError``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parsed = {"true": True, "false": False}.get(values.lower())
        if parsed is None:
            raise ValueError(
                f"Invalid boolean value: {values}. Expected 'true' or 'false'."
            )
        setattr(namespace, self.dest, parsed)
548
549


550
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    """SortedHelpFormatter that sorts arguments by their option strings."""

    # Compiled once per class instead of on every _split_lines call.
    # The patterns also include whitespace after the newline(s).
    _SINGLE_NEWLINE = re.compile(r"(?<!\n)\n(?!\n)\s*")
    _MULTIPLE_NEWLINES = re.compile(r"\n{2,}\s*")

    def _split_lines(self, text, width):
        """
        1. Sentences split across lines have their single newlines removed.
        2. Paragraphs and explicit newlines are split into separate lines.
        3. Each line is wrapped to the specified width (width of terminal).
        """
        text = self._SINGLE_NEWLINE.sub(" ", text)
        paragraphs = self._MULTIPLE_NEWLINES.split(text)
        # Flatten with a comprehension; sum(list_of_lists, []) is quadratic.
        return [
            line for paragraph in paragraphs for line in textwrap.wrap(paragraph, width)
        ]

    def add_arguments(self, actions):
        actions = sorted(actions, key=lambda x: x.option_strings)
        super().add_arguments(actions)
569
570


571
class FlexibleArgumentParser(ArgumentParser):
572
573
    """ArgumentParser that allows both underscore and dash in names."""

574
    # Actions registered with ``deprecated=True`` (collected by the
    # add_argument overrides on Python < 3.13). A class-level set, so it is
    # shared by every parser instance.
    _deprecated: set[Action] = set()
575
576
577
578
579
580
581
    # Help text explaining the dotted / ``+`` JSON argument syntax; appended
    # to ``--help=<keyword>`` output by format_help.
    _json_tip: str = (
        "When passing JSON CLI arguments, the following sets of arguments "
        "are equivalent:\n"
        '   --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
        "   --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
        "Additionally, list elements can be passed individually using +:\n"
        '   --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
582
583
        "   --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
    )
584
    # Keyword given via ``--help=<keyword>``; parse_args stores it on the
    # class and format_help reads it to filter the help output.
    _search_keyword: str | None = None
585

586
    def __init__(self, *args, **kwargs):
        """Create the parser.

        The extra keyword ``add_json_tip`` (default True) is stored on the
        instance rather than forwarded to ``ArgumentParser``. Unless the caller
        chooses one, ``formatter_class`` defaults to SortedHelpFormatter.
        """
        self.add_json_tip = kwargs.pop("add_json_tip", True)
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = SortedHelpFormatter
        super().__init__(*args, **kwargs)

594
    if sys.version_info < (3, 13):
        # Emulate the ``deprecated`` kwarg of add_argument for Python 3.12
        # and below: deprecated actions are recorded in
        # FlexibleArgumentParser._deprecated and a warning is logged after
        # parsing when such an argument was given a non-default value.

        def parse_known_args(self, args=None, namespace=None):
            if args is not None and "--disable-log-requests" in args:
                # Warned about explicitly: the generic check below only fires
                # when the parsed value differs from the action default, and
                # passing --disable-log-requests leaves its value at the
                # default.
                logger.warning_once(
                    "argument '--disable-log-requests' is deprecated and "
                    "replaced with '--enable-log-requests'. This will be "
                    "removed in v0.12.0."
                )

            namespace, args = super().parse_known_args(args, namespace)
            for action in FlexibleArgumentParser._deprecated:
                dest = action.dest
                changed = (
                    hasattr(namespace, dest)
                    and getattr(namespace, dest) != action.default
                )
                if changed:
                    logger.warning_once("argument '%s' is deprecated", dest)
            return namespace, args

        def add_argument(self, *args, **kwargs):
            is_deprecated = kwargs.pop("deprecated", False)
            action = super().add_argument(*args, **kwargs)
            if is_deprecated:
                FlexibleArgumentParser._deprecated.add(action)
            return action

        class _FlexibleArgumentGroup(_ArgumentGroup):
            """Argument group that accepts ``deprecated=`` like the parser."""

            def add_argument(self, *args, **kwargs):
                is_deprecated = kwargs.pop("deprecated", False)
                action = super().add_argument(*args, **kwargs)
                if is_deprecated:
                    FlexibleArgumentParser._deprecated.add(action)
                return action

        def add_argument_group(self, *args, **kwargs):
            # Mirrors argparse's add_argument_group, but builds the group
            # subclass above so groups understand ``deprecated=`` too.
            new_group = self._FlexibleArgumentGroup(self, *args, **kwargs)
            self._action_groups.append(new_group)
            return new_group
634

635
636
637
638
639
640
641
642
643
644
645
646
    def format_help(self):
        """Render help text, honoring ``--help=<keyword>`` searches.

        Parsers that have subparsers use argparse's default output. Otherwise:

        * no keyword: usage, description, a "Config Groups" overview (one row
          per argument group that has actions) and the epilog;
        * ``all``: argparse's full help, with the JSON tip as epilog;
        * a group title: that group's help followed by the JSON tip;
        * anything else: every argument whose option string contains the
          keyword, or a "no match" hint.
        """
        # Only use custom help formatting for bottom level parsers
        if self._subparsers is not None:
            return super().format_help()

        formatter = self._get_formatter()

        # Handle keyword search of the args
        if (search_keyword := self._search_keyword) is not None:
            # Normalise the search keyword
            search_keyword = search_keyword.lower().replace("_", "-")
            # Return full help if searching for 'all'
            if search_keyword == "all":
                self.epilog = self._json_tip
                return super().format_help()

            # Return group help if searching for a group title
            for group in self._action_groups:
                if group.title and group.title.lower() == search_keyword:
                    formatter.start_section(group.title)
                    formatter.add_text(group.description)
                    formatter.add_arguments(group._group_actions)
                    formatter.end_section()
                    formatter.add_text(self._json_tip)
                    return formatter.format_help()

            # Return matched args if searching for an arg name
            matched_actions = [
                action
                for group in self._action_groups
                for action in group._group_actions
                if any(search_keyword in opt.lower() for opt in action.option_strings)
            ]
            if matched_actions:
                formatter.start_section(f"Arguments matching '{search_keyword}'")
                formatter.add_arguments(matched_actions)
                formatter.end_section()
                formatter.add_text(self._json_tip)
                return formatter.format_help()

            # No match found
            formatter.add_text(
                f"No group or arguments matching '{search_keyword}'.\n"
                "Use '--help' to see available groups or "
                "'--help=all' to see all available parameters."
            )
            return formatter.format_help()

        # usage
        formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)

        # description
        formatter.add_text(self.description)

        # positionals, optionals and user-defined groups
        formatter.start_section("Config Groups")
        # An untitled group (title=None) would make the "<24" format spec
        # raise TypeError, so it is rendered with an empty title instead.
        config_groups = "".join(
            f"{group.title or '': <24}{group.description or ''}\n"
            for group in self._action_groups
            if group._group_actions
        )
        formatter.add_text(config_groups)
        formatter.end_section()

        # epilog
        formatter.add_text(self.epilog)

        # determine help from format above
        return formatter.format_help()
708

709
710
711
712
713
    def parse_args(  # type: ignore[override]
        self,
        args: list[str] | None = None,
        namespace: Namespace | None = None,
    ):
        """Parse *args* (default: ``sys.argv[1:]``) after normalizing them.

        Before delegating to ``ArgumentParser.parse_args`` this:

        * rewrites ``vllm serve ... --model <m>`` to the positional form,
          with a deprecation warning;
        * expands ``--config`` via ``_pull_args_from_config``;
        * turns ``--help=<keyword>`` into ``--help`` plus a search keyword;
        * converts underscores to dashes in option names, up to the first
          ``.``;
        * rewrites ``-O<mode>``, ``-O=<mode>`` and ``-O <0-3>`` to
          ``-O.mode``;
        * merges dotted (``--arg.key value``) and ``+``-suffixed list options
          into a single JSON value per top-level option.
        """
        if args is None:
            args = sys.argv[1:]

        # Check for --model in command line arguments first
        if args and args[0] == "serve":
            model_idx = next(
                (
                    i
                    for i, arg in enumerate(args)
                    if arg == "--model" or arg.startswith("--model=")
                ),
                None,
            )
            if model_idx is not None:
                logger.warning(
                    "With `vllm serve`, you should provide the model as a "
                    "positional argument or in a config file instead of via "
                    "the `--model` option. "
                    "The `--model` option will be removed in v0.13."
                )

                if args[model_idx] == "--model":
                    model_tag = args[model_idx + 1]
                    rest_start_idx = model_idx + 2
                else:
                    model_tag = args[model_idx].removeprefix("--model=")
                    rest_start_idx = model_idx + 1

                # Move <model> to the front, e.g.:
                # [Before]
                # vllm serve -tp 2 --model <model> --enforce-eager --port 8001
                # [After]
                # vllm serve <model> -tp 2 --enforce-eager --port 8001
                args = [
                    "serve",
                    model_tag,
                    *args[1:model_idx],
                    *args[rest_start_idx:],
                ]

        if "--config" in args:
            args = self._pull_args_from_config(args)

        def repl(match: re.Match) -> str:
            """Replaces underscores with dashes in the matched string."""
            return match.group(0).replace("_", "-")

        # Everything between the first -- and the first .
        pattern = re.compile(r"(?<=--)[^\.]*")

        # Convert underscores to dashes in argument names. Only the part
        # before the first "." is touched, so nested keys keep their spelling.
        processed_args = list[str]()
        for i, arg in enumerate(args):
            if arg.startswith("--help="):
                FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
                processed_args.append("--help")
            elif arg.startswith("--"):
                if "=" in arg:
                    key, value = arg.split("=", 1)
                    key = pattern.sub(repl, key, count=1)
                    processed_args.append(f"{key}={value}")
                else:
                    key = pattern.sub(repl, arg, count=1)
                    processed_args.append(key)
            elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
                # allow -O flag to be used without space, e.g. -O3 or -Odecode
                # -O.<...> handled later
                # also handle -O=<mode> here
                mode = arg[3:] if arg[2] == "=" else arg[2:]
                processed_args.append(f"-O.mode={mode}")
            elif (
                arg == "-O"
                and i + 1 < len(args)
                and args[i + 1] in {"0", "1", "2", "3"}
            ):
                # Convert -O <n> to -O.mode <n>
                processed_args.append("-O.mode")
            else:
                processed_args.append(arg)

        def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
            """Creates a nested dictionary from a list of keys and a value.

            For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
            `{"a": {"b": {"c": 1}}}`
            """
            nested_dict: Any = value
            for key in reversed(keys):
                nested_dict = {key: nested_dict}
            return nested_dict

        def recursive_dict_update(
            original: dict[str, Any],
            update: dict[str, Any],
        ) -> set[str]:
            """Recursively updates a dictionary with another dictionary.
            Returns a set of duplicate keys that were overwritten.
            """
            duplicates = set[str]()
            for k, v in update.items():
                if isinstance(v, dict) and isinstance(original.get(k), dict):
                    nested_duplicates = recursive_dict_update(original[k], v)
                    duplicates |= {f"{k}.{d}" for d in nested_duplicates}
                elif isinstance(v, list) and isinstance(original.get(k), list):
                    original[k] += v
                else:
                    if k in original:
                        duplicates.add(k)
                    original[k] = v
            return duplicates

        delete = set[int]()
        dict_args = defaultdict[str, dict[str, Any]](dict)
        duplicates = set[str]()
        for i, processed_arg in enumerate(processed_args):
            if i in delete:  # skip if value from previous arg
                continue

            if processed_arg.startswith("-") and "." in processed_arg:
                if "=" in processed_arg:
                    processed_arg, value_str = processed_arg.split("=", 1)
                    if "." not in processed_arg:
                        # False positive, '.' was only in the value
                        continue
                else:
                    value_str = processed_args[i + 1]
                    delete.add(i + 1)

                if processed_arg.endswith("+"):
                    processed_arg = processed_arg[:-1]
                    value_str = json.dumps(list(value_str.split(",")))

                key, *keys = processed_arg.split(".")
                try:
                    value = json.loads(value_str)
                except json.decoder.JSONDecodeError:
                    value = value_str

                # Merge all values with the same key into a single dict
                arg_dict = create_nested_dict(keys, value)
                arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
                duplicates |= {f"{key}.{d}" for d in arg_duplicates}
                delete.add(i)
        # Drop the dotted args (and the values they consumed) merged above
        processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
        if duplicates:
            logger.warning("Found duplicate keys %s", ", ".join(duplicates))

        # Add the dict args back as if they were originally passed as JSON
        for dict_arg, dict_value in dict_args.items():
            processed_args.append(dict_arg)
            processed_args.append(json.dumps(dict_value))

        return super().parse_args(processed_args, namespace)
868

869
870
871
872
    def check_port(self, value):
        """Validate a port given on the command line.

        `value` is converted with `int()` and must then fall inside the
        inclusive range [1024, 65535].

        Returns:
            The port as an `int`.

        Raises:
            ArgumentTypeError: if `value` is not an integer, or lies outside
                the accepted range.
        """
        try:
            port = int(value)
        except ValueError:
            raise ArgumentTypeError("Port must be an integer") from None

        if port < 1024 or port > 65535:
            raise ArgumentTypeError("Port must be between 1024 and 65535")
        return port

    def _pull_args_from_config(self, args: list[str]) -> list[str]:
882
883
        """Method to pull arguments specified in the config file
        into the command-line args variable.
884
885

        The arguments in config file will be inserted between
886
        the argument list.
887

888
889
890
891
892
893
894
895
896
897
        example:
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        ```python
        $: vllm {serve,chat,complete} "facebook/opt-12B" \
            --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
898
899
            "facebook/opt-12B",
            '--config', 'config.yaml',
900
901
902
903
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
904
905
906
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
907
908
909
910
911
            '-tp', '2'
            ]
        ```

        Please note how the config args are inserted after the sub command.
912
        this way the order of priorities is maintained when these are args
913
914
        parsed by super().
        """
915
        assert args.count("--config") <= 1, "More than one config file specified!"
916

917
        index = args.index("--config")
918
        if index == len(args) - 1:
919
920
921
922
            raise ValueError(
                "No config file specified! \
                             Please check your command-line arguments."
            )
923
924
925

        file_path = args[index + 1]

926
        config_args = self.load_config_file(file_path)
927

928
        # 0th index might be the sub command {serve,chat,complete,...}
929
        # optionally followed by model_tag (only for serve)
930
931
932
933
        # followed by config args
        # followed by rest of cli args.
        # maintaining this order will enforce the precedence
        # of cli > config > defaults
934
        if args[0].startswith("-"):
935
            # No sub command (e.g., api_server entry point)
936
            args = config_args + args[0:index] + args[index + 2 :]
937
        elif args[0] == "serve":
938
939
            model_in_cli = len(args) > 1 and not args[1].startswith("-")
            model_in_config = any(arg == "--model" for arg in config_args)
940
941

            if not model_in_cli and not model_in_config:
942
                raise ValueError(
943
                    "No model specified! Please specify model either "
944
945
                    "as a positional argument or in a config file."
                )
946
947
948

            if model_in_cli:
                # Model specified as positional arg, keep CLI version
949
950
951
952
953
954
955
                args = (
                    [args[0]]
                    + [args[1]]
                    + config_args
                    + args[2:index]
                    + args[index + 2 :]
                )
956
957
            else:
                # No model in CLI, use config if available
958
                args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
959
        else:
960
            args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
961
962
963

        return args

964
    def load_config_file(self, file_path: str) -> list[str]:
        """Loads a yaml file and returns the key value pairs as a
        flattened list with argparse like pattern
        ```yaml
            port: 12323
            tensor-parallel-size: 4
        ```
        returns:
            processed_args: list[str] = [
                '--port', '12323',
                '--tensor-parallel-size', '4'
            ]

        Value handling per key:
            - `bool`, when the key is not the dest of a `StoreBoolean`
              action: `--key` is emitted only for `True`; `False` is dropped.
            - `list`: `--key` followed by `str(item)` for each item; an empty
              list is dropped.
            - anything else: `--key` followed by `str(value)`.

        Raises:
            ValueError: if the file extension is not `yaml`/`yml`, or the
                file's top level is not a mapping.
            Exception: any error from opening or parsing the file is logged
                and re-raised unchanged.
        """
        extension: str = file_path.split(".")[-1]
        if extension not in ("yaml", "yml"):
            raise ValueError(
                f"Config file must be of a yaml/yml type. {extension} supplied"
            )

        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception:
            logger.error(
                "Unable to read the config file at %s. Make sure path is correct",
                file_path,
            )
            raise

        # An empty YAML document parses to None: treat it as "no options".
        if config is None:
            config = {}
        # only expecting a flat dictionary of atomic types
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {file_path} must contain a mapping of argument "
                f"names to values, got {type(config).__name__}"
            )

        store_boolean_arguments = {
            action.dest for action in self._actions if isinstance(action, StoreBoolean)
        }

        processed_args: list[str] = []
        for key, value in config.items():
            if isinstance(value, bool) and key not in store_boolean_arguments:
                if value:
                    processed_args.append("--" + key)
            elif isinstance(value, list):
                if value:
                    processed_args.append("--" + key)
                    processed_args.extend(str(item) for item in value)
            else:
                processed_args.append("--" + key)
                processed_args.append(str(value))

        return processed_args


class AtomicCounter:
    """Counter whose updates are serialized by a `threading.Lock`.

    `inc` and `dec` modify the stored value while holding the lock and return
    the value they produced, so concurrent callers never lose an update.
    """

    def __init__(self, initial=0):
        """Create a counter that starts at `initial`."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add `num` under the lock and return the resulting value."""
        with self._lock:
            self._value += num
            result = self._value
        return result

    def dec(self, num=1):
        """Subtract `num` under the lock and return the resulting value."""
        with self._lock:
            self._value -= num
            result = self._value
        return result

    @property
    def value(self):
        # Plain read of the current value (the lock is not taken).
        return self._value


def kill_process_tree(pid: int):
    """
    Kill the process `pid` and all of its descendants by sending SIGKILL.

    Descendants are signalled first, then the process itself. Processes that
    have already exited by the time they are looked up or signalled are
    ignored.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Get all children recursively. The parent can exit between the lookup
    # above and this call, in which case psutil raises NoSuchProcess; fall
    # through so the (suppressed) final kill is still attempted.
    try:
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        children = []

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
1071
def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE limit to `target_soft_limit` if lower.

    The hard limit is left unchanged. Nothing is done on Windows or when the
    current soft limit already meets the target. If the limit cannot be
    raised, a warning is logged instead of raising.
    """
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            # ValueError: e.g. the target exceeds the hard limit.
            # OSError: the underlying setrlimit system call failed.
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n",
                current_soft,
                e,
            )


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    """Return the formatted traceback of the exception currently being handled.

    Intended to be called from inside an `except` block; it formats the
    triple returned by `sys.exc_info()` with `traceback.format_exception`.
    """
    exc_info = sys.exc_info()
    return "".join(traceback.format_exception(*exc_info))

def split_zmq_path(path: str) -> tuple[str, str, str]:
    """Split a ZMQ endpoint into `(scheme, host, port)`.

    Missing host or port components come back as empty strings.

    Raises:
        ValueError: if there is no scheme, if a `tcp` path lacks a host or
            port, or if a non-`tcp` path carries a port.
    """
    url = urlparse(path)
    if not url.scheme:
        raise ValueError(f"Invalid zmq path: {path}")

    scheme = url.scheme
    host = url.hostname or ""
    port = str(url.port or "")

    is_tcp = scheme == "tcp"
    # tcp needs both a host and a port; other transports must not have a port.
    if (is_tcp and not (host and port)) or (not is_tcp and port):
        raise ValueError(f"Invalid zmq path: {path}")

    return scheme, host, port

def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
    """Make a ZMQ path from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - can be an IPv4 address, IPv6 address, or hostname.
        port: Optional port number, only used for TCP sockets.

    Returns:
        A properly formatted ZMQ path string.
    """
1134
    if port is None:
1135
1136
1137
1138
1139
1140
        return f"{scheme}://{host}"
    if is_valid_ipv6_address(host):
        return f"{scheme}://[{host}]:{port}"
    return f"{scheme}://{host}:{port}"


1141
1142
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: zmq.asyncio.Context | zmq.Context,  # type: ignore[name-defined]
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    identity: bytes | None = None,
    linger: int | None = None,
) -> zmq.Socket | zmq.asyncio.Socket:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    Args:
        ctx: Context (sync or asyncio) used to create the socket.
        path: ZMQ endpoint, e.g. `tcp://host:port` or `ipc:///some/path`.
        socket_type: ZMQ socket type constant such as `zmq.PUSH`.
        bind: Bind (True) or connect (False). When None, PUSH, SUB and XSUB
            sockets connect and every other type binds.
        identity: Optional value for the `zmq.IDENTITY` option.
        linger: Optional value for the `zmq.LINGER` option.

    Returns:
        The configured socket, already bound or connected to `path`.

    Raises:
        Any error from option setting, `split_zmq_path`, or `bind`/`connect`.
        The socket is closed before the error propagates so a failed setup
        does not leave it open in `ctx`.
    """
    mem = psutil.virtual_memory()

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1

    if bind is None:
        bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

    # Named `sock` so it does not shadow the stdlib `socket` module.
    sock = ctx.socket(socket_type)
    try:
        if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
            sock.setsockopt(zmq.RCVHWM, 0)
            sock.setsockopt(zmq.RCVBUF, buf_size)

        if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
            sock.setsockopt(zmq.SNDHWM, 0)
            sock.setsockopt(zmq.SNDBUF, buf_size)

        if identity is not None:
            sock.setsockopt(zmq.IDENTITY, identity)

        if linger is not None:
            sock.setsockopt(zmq.LINGER, linger)

        if socket_type == zmq.XPUB:
            sock.setsockopt(zmq.XPUB_VERBOSE, True)

        # Determine if the path is a TCP socket with an IPv6 address.
        # Enable IPv6 on the zmq socket if so.
        scheme, host, _ = split_zmq_path(path)
        if scheme == "tcp" and is_valid_ipv6_address(host):
            sock.setsockopt(zmq.IPV6, 1)

        if bind:
            sock.bind(path)
        else:
            sock.connect(path)
    except BaseException:
        # Nothing was sent yet, so drop the half-configured socket right away.
        sock.close(linger=0)
        raise

    return sock


@contextlib.contextmanager
def zmq_socket_ctx(
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    linger: int = 0,
    identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
    """Yield a ZMQ socket created in a private, short-lived context.

    On exit the context is destroyed with the given `linger`. A
    `KeyboardInterrupt` raised inside the managed block is logged at debug
    level and not propagated.
    """
    context = zmq.Context()  # type: ignore[attr-defined]
    try:
        sock = make_zmq_socket(context, path, socket_type, bind=bind, identity=identity)
        yield sock
    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")
    finally:
        context.destroy(linger=linger)


def _maybe_force_spawn():
    """Force VLLM_WORKER_MULTIPROC_METHOD to `spawn` when required.

    Returns immediately if the variable is already `spawn`. Otherwise `spawn`
    is forced, with a warning listing the reasons, when running inside a Ray
    actor or when CUDA (or, failing that, XPU) is already initialized. Inside
    a Ray actor, `RAY_ADDRESS` is also exported for child processes.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons: list[str] = []

    if is_in_ray_actor():
        # Environment variables are inherited even by spawned subprocesses,
        # so exporting the GCS address tells them how to reach the Ray cluster.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if not reasons:
        return

    logger.warning(
        "We must use the `spawn` multiprocessing start method. "
        "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
        "See https://docs.vllm.ai/en/latest/usage/"
        "troubleshooting.html#python-multiprocessing "
        "for more information. Reasons: %s",
        "; ".join(reasons),
    )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Return a multiprocessing context for the configured start method.

    The method is read from VLLM_WORKER_MULTIPROC_METHOD (fork by default).
    `_maybe_force_spawn` runs first and may override that setting to `spawn`
    under certain conditions.
    """
    _maybe_force_spawn()
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)


def bind_kv_cache(
    ctx: dict[str, Any],
    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
    shared_kv_cache_layers: dict[str, str] | None = None,
) -> None:
    """Bind the kv_cache tensors to the Attention modules in `ctx`.

    Roughly
    `ctx[layer_name].kv_cache[ve] = kv_cache[ve][extract_layer_index(layer_name)]`,
    except that the layer index is compacted: `kv_cache[ve]` is indexed by the
    position of a layer's index among the sorted, distinct indices of the
    layers that need a KV cache.

    Special things handled here:
    1. Some models have non-attention layers, e.g., Jamba
    2. Pipeline parallelism, each rank only has a subset of layers
    3. Encoder attention has no kv cache
    4. Encoder-decoder models, encoder-decoder attention and decoder-only
       attention of the same layer (e.g., bart's decoder.layers.1.self_attn
       and decoder.layers.1.encoder_attn) is mapped to the same kv cache
       tensor
    5. Some models have attention layers that share kv cache with previous
       layers, this is specified through shared_kv_cache_layers

    Args:
        ctx: Mapping from layer name to the module holding `kv_cache`.
        kv_cache: Tensors indexed as `kv_cache[virtual_engine][i]`.
        shared_kv_cache_layers: Optional `layer_name -> target_layer_name`
            map; each such layer reuses the target layer's `kv_cache` list.
    """
    if shared_kv_cache_layers is None:
        shared_kv_cache_layers = {}
    from vllm.attention import AttentionType
    from vllm.model_executor.models.utils import extract_layer_index

    layer_need_kv_cache = [
        layer_name
        for layer_name in ctx
        if (
            hasattr(ctx[layer_name], "attn_type")
            and ctx[layer_name].attn_type
            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
        )
        and ctx[layer_name].kv_sharing_target_layer_name is None
    ]
    layer_index_sorted = sorted(
        {extract_layer_index(layer_name) for layer_name in layer_need_kv_cache}
    )
    # Dense position of each layer index; a dict lookup replaces repeated
    # list.index() scans, which made the loop below quadratic.
    kv_cache_idx_by_layer = {
        layer_idx: pos for pos, layer_idx in enumerate(layer_index_sorted)
    }
    for layer_name in layer_need_kv_cache:
        kv_cache_idx = kv_cache_idx_by_layer[extract_layer_index(layer_name)]
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        assert extract_layer_index(target_layer_name) < extract_layer_index(
            layer_name
        ), "v0 doesn't support interleaving kv sharing"
        ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache


def run_method(
    obj: Any,
1315
    method: str | bytes | Callable,
1316
1317
1318
    args: tuple[Any],
    kwargs: dict[str, Any],
) -> Any:
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
1332
1333
1334
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
1335
1336
1337
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358


def import_pynvml():
    """
    Return vLLM's vendored copy of the official `pynvml` module.

    Background:

    libnvml.so is the library behind nvidia-smi, and pynvml is a Python
    wrapper around it. We use it to get GPU status without initializing a
    CUDA context in the current process. Historically, two packages have
    provided a module named `pynvml`:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): the official
        wrapper. It is a dependency of vLLM, installed together with vLLM,
        and provides a Python module named `pynvml`.
    - `pynvml` (https://pypi.org/project/pynvml/): an unofficial wrapper.
        Before version 12.0 it also provided a module `pynvml`, conflicting
        with the official one. Worse, that module is a package and so takes
        priority over the official one, which is a standalone Python file,
        causing errors when both are installed. From version 12.0 it moved
        to a new module named `pynvml_utils` to avoid the conflict.

    Many community packages pick up the unofficial one by mistake; for
    example `nvcr.io/nvidia/pytorch:24.12-py3` ships it, which causes errors
    (see https://github.com/vllm-project/vllm/issues/12847). To avoid these
    problems, the official `pynvml` module is copied into our codebase and
    used directly.
    """
    from vllm.third_party import pynvml

    return pynvml


def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
    """
    A replacement for `abc.ABC`.
    When we use `abc.ABC`, subclasses will fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class, and log a (debug-level) message if the method is not
    implemented in the subclass.

    The check runs after every `__init__` of `cls`: each public bound method
    of the new instance whose source text contains `NotImplementedError` is
    reported. Attributes that are not bound methods, and methods whose source
    cannot be retrieved, are skipped.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        unimplemented_methods = []
        for attr_name in dir(self):
            # bypass inner method
            if attr_name.startswith("_"):
                continue

            try:
                attr = getattr(self, attr_name)
            except AttributeError:
                continue
            # Only bound methods carry `__func__`. Skipping everything else
            # (plain attributes, static functions, nested classes) also avoids
            # reusing the previous iteration's function by accident.
            attr_func = getattr(attr, "__func__", None)
            if attr_func is None:
                continue
            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source not available (e.g. built dynamically or in C).
                continue
            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)
        if unimplemented_methods:
            logger.debug(
                "Methods %s not implemented in %s",
                ",".join(unimplemented_methods),
                self,
            )

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    type.__setattr__(cls, "__init__", wrapped_init)
    return cls


@contextlib.contextmanager
1416
def cprofile_context(save_file: str | None = None):
1417
1418
1419
1420
    """Run a cprofile

    Args:
        save_file: path to save the profile result. "1" or
1421
            None will result in printing to stdout.
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
    """
    import cProfile

    prof = cProfile.Profile()
    prof.enable()

    try:
        yield
    finally:
        prof.disable()
        if save_file and save_file != "1":
            prof.dump_stats(save_file)
        else:
            prof.print_stats(sort="cumtime")


1438
def cprofile(save_file: str | None = None, enabled: bool = True):
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
    """Decorator to profile a Python method using cProfile.

    Args:
        save_file: Path to save the profile result.
            If "1", None, or "", results will be printed to stdout.
        enabled: Set to false to turn this into a no-op
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not enabled:
                # If profiling is disabled, just call the function directly.
                return func(*args, **kwargs)

            with cprofile_context(save_file):
                return func(*args, **kwargs)

        return wrapper

    return decorator
1460
1461


1462
1463
# Only relevant for models using ALiBi (e.g., MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
    """Report whether the model config indicates ALiBi.

    Checked in order:
    - Falcon: a truthy `alibi` attribute on the text config;
    - Bloom: "BloomForCausalLM" in `hf_config.architectures`;
    - codellm_1b_alibi: `position_encoding_type == "alibi"`;
    - MPT: an `alibi` entry in `attn_config` (a dict or an object).

    Like a chain of `or`, the first truthy check's value is returned;
    otherwise the MPT check's (falsy) result.
    """
    cfg = model_config.hf_text_config

    # Falcon
    falcon_alibi = getattr(cfg, "alibi", False)
    if falcon_alibi:
        return falcon_alibi

    # Bloom
    if "BloomForCausalLM" in getattr(model_config.hf_config, "architectures", []):
        return True

    # codellm_1b_alibi
    codellm_alibi = getattr(cfg, "position_encoding_type", "") == "alibi"
    if codellm_alibi:
        return codellm_alibi

    # MPT
    if not hasattr(cfg, "attn_config"):
        return False
    attn_config = cfg.attn_config
    if isinstance(attn_config, dict):
        return attn_config.get("alibi", False) or False
    return getattr(attn_config, "alibi", False)


def sha256(input: Any) -> bytes:
    """Return the SHA-256 digest of `input` after pickling it.

    `input` is serialized with `pickle.HIGHEST_PROTOCOL`, so any picklable
    object can be hashed. No hash seed is mixed in; if you need one, prepend
    it to the input yourself.

    Args:
        input: Any picklable Python object.

    Returns:
        The raw SHA-256 digest of the pickled bytes.
    """
    hasher = hashlib.sha256()
    hasher.update(pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL))
    return hasher.digest()


def sha256_cbor(input: Any) -> bytes:
    """
    Return the SHA-256 digest of the canonical CBOR encoding of `input`.

    Because CBOR is language-neutral, this hash does not depend on Python's
    pickle format.

    Args:
        input: Object to serialize and hash. Basic Python types and nested
            lists, tuples and dictionaries are supported; custom classes
            must implement CBOR serialization.

    Returns:
        The raw SHA-256 digest of `cbor2.dumps(input, canonical=True)`.
    """
    encoded = cbor2.dumps(input, canonical=True)
    digest = hashlib.sha256(encoded)
    return digest.digest()


def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
1524
1525
1526
1527
1528
1529
1530
1531
1532
    """Get a hash function by name, or raise an error if
    the function is not found.
    Args:
        hash_fn_name: Name of the hash function.
    Returns:
        A hash function.
    """
    if hash_fn_name == "sha256":
        return sha256
1533
1534
    if hash_fn_name == "sha256_cbor":
        return sha256_cbor
1535
1536
1537
1538

    raise ValueError(f"Unsupported hash function: {hash_fn_name}")


1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
@cache
def _has_module(module_name: str) -> bool:
    """Return True if *module_name* can be found in the current environment.

    The result is cached so that subsequent queries for the same module incur
    no additional overhead.
    """
    return importlib.util.find_spec(module_name) is not None


def has_pplx() -> bool:
    """Report whether the optional `pplx_kernels` package can be found."""
    module_name = "pplx_kernels"
    return _has_module(module_name)


def has_deep_ep() -> bool:
    """Report whether the optional `deep_ep` package can be found."""
    module_name = "deep_ep"
    return _has_module(module_name)


def has_deep_gemm() -> bool:
    """Report whether the optional `deep_gemm` package can be found."""
    module_name = "deep_gemm"
    return _has_module(module_name)


def has_triton_kernels() -> bool:
    """Report whether the optional `triton_kernels` package can be found."""
    module_name = "triton_kernels"
    return _has_module(module_name)


def has_tilelang() -> bool:
    """Report whether the optional `tilelang` package can be found."""
    module_name = "tilelang"
    return _has_module(module_name)


def set_process_title(
    name: str, suffix: str = "", prefix: str | None = None
) -> None:
    """
    Set the current process title to a specific name with an
    optional suffix.

    The title becomes `{prefix}::{name}`, or `{prefix}::{name}_{suffix}`
    when a suffix is given.

    Args:
        name: The title to assign to the current process.
        suffix: An optional suffix to append to the base name, joined by `_`.
        prefix: A prefix to prepend to the front separated by `::`.
            Defaults to `envs.VLLM_PROCESS_NAME_PREFIX`, looked up when this
            function is called rather than once when the module is imported.
    """
    if prefix is None:
        prefix = envs.VLLM_PROCESS_NAME_PREFIX
    if suffix:
        name = f"{name}_{suffix}"
    setproctitle.setproctitle(f"{prefix}::{name}")


def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Prepend each output line with process-specific prefix.

    Replaces `file.write` so that every line written through it starts with
    `({worker_name} pid={pid}) ` wrapped in `CYAN`/`RESET`. A
    `start_new_line` attribute stored on `file` records whether the previous
    write ended with a newline, so a line split across several `write` calls
    is prefixed only once.
    """

    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write

    def write_with_prefix(s: str) -> int:
        # Like TextIOBase.write, return the number of characters of `s`
        # written (the added prefix is not counted).
        if not s:
            return 0
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        while (next_idx := s.find("\n", idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                # `s` ended with a newline: prefix the next write instead.
                file.start_new_line = True  # type: ignore[attr-defined]
                return len(s)
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]
        return len(s)

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]


def decorate_logs(process_name: str | None = None) -> None:
    """
    Prefix every line later written to stdout and stderr with the process
    name and PID.

    Call this before initializing the api_server, engine_core, or worker
    classes so that all of the process's subsequent output is tagged, which
    makes interleaved logs from multiple processes easy to tell apart.

    Args:
        process_name: Name to show in the prefix. When None, the name of the
            current process from the multiprocessing context is used.
    """
    name = (
        process_name
        if process_name is not None
        else get_mp_context().current_process().name
    )
    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, name, pid)


def length_from_prompt_token_ids_or_embeds(
1646
1647
    prompt_token_ids: list[int] | None,
    prompt_embeds: torch.Tensor | None,
1648
) -> int:
1649
    """Calculate the request length (in number of tokens) give either
1650
1651
    prompt_token_ids or prompt_embeds.
    """
1652
1653
    prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
    prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
1654
1655
1656

    if prompt_token_len is None:
        if prompt_embeds_len is None:
1657
            raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
1658
1659
        return prompt_embeds_len
    else:
1660
        if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
1661
1662
1663
            raise ValueError(
                "Prompt token ids and prompt embeds had different lengths"
                f" prompt_token_ids={prompt_token_len}"
1664
1665
                f" prompt_embeds={prompt_embeds_len}"
            )
1666
        return prompt_token_len
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679


@contextlib.contextmanager
def set_env_var(key: str, value: str) -> Iterator[None]:
    """Temporarily set environment variable ``key`` to ``value``.

    On exit the previous state is restored: the prior value is written back
    if the variable existed before entry, otherwise the variable is removed.

    Args:
        key: Name of the environment variable.
        value: Value to set while the context is active.
    """
    old = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            # pop() instead of del: if the managed block already removed the
            # variable, a KeyError here would mask any in-flight exception.
            os.environ.pop(key, None)
        else:
            os.environ[key] = old
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699


def unique_filepath(fn: Callable[[int], Path]) -> Path:
    """
    Return the first path ``fn(n)`` that does not yet exist, probing
    n = 0, 1, 2, ... in increasing order.

    ``fn`` must map an integer to a path that embeds that integer at a fixed
    position, so that each probe yields a distinct candidate.

    Note: existence is checked before the caller creates the file, so this
    is subject to a TOCTOU race. Callers should create the file atomically
    (e.g. ``open`` with mode ``'x'``) to be safe under concurrency.
    """
    counter = 0
    candidate = fn(counter)
    while candidate.exists():
        counter += 1
        candidate = fn(counter)
    return candidate