# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import datetime
import enum
import getpass
import importlib
import inspect
import multiprocessing
import os
import signal
import sys
import tempfile
import threading
import traceback
import uuid
import warnings
import weakref
from collections.abc import Callable
from functools import cache, partial, wraps
from typing import TYPE_CHECKING, Any, TypeVar

import cloudpickle
import psutil
import torch

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.ray.lazy_utils import is_in_ray_actor

# Import utilities from specialized modules for backward compatibility
from vllm.utils.argparse_utils import (
    FlexibleArgumentParser,
    SortedHelpFormatter,
    StoreBoolean,
)
from vllm.utils.math_utils import (
    cdiv,
    next_power_of_2,
    prev_power_of_2,
    round_down,
    round_up,
)
from vllm.utils.platform_utils import cuda_is_initialized, xpu_is_initialized

# Names re-exported from the specialized ``vllm.utils.*`` submodules imported
# above, so existing ``from vllm.utils import ...`` call sites keep working.
__all__ = [
    # -- vllm.utils.argparse_utils --
    "FlexibleArgumentParser",
    "SortedHelpFormatter",
    "StoreBoolean",
    # -- vllm.utils.math_utils --
    "cdiv",
    "next_power_of_2",
    "prev_power_of_2",
    "round_down",
    "round_up",
]

# Deprecated ``vllm.utils.<name>`` aliases -> name of the ``vllm.utils``
# submodule that now defines them. Resolved lazily by the module-level
# ``__getattr__``.
_DEPRECATED_MAPPINGS = {
    "cprofile": "profiling",
    "cprofile_context": "profiling",
    "get_open_port": "network_utils",
}


def __getattr__(name: str) -> Any:  # noqa: D401 - short deprecation docstring
    """Module-level getattr to handle deprecated utilities.

    Names listed in ``_DEPRECATED_MAPPINGS`` are imported lazily from their new
    ``vllm.utils.<submodule>`` location, and a ``DeprecationWarning`` pointing
    at that location is emitted.

    Raises:
        AttributeError: if ``name`` is not a deprecated alias (regular module
            attributes never reach this hook).
    """
    submodule_name = _DEPRECATED_MAPPINGS.get(name)
    if submodule_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(
        f"vllm.utils.{name} is deprecated and will be removed in a future version. "
        f"Use vllm.utils.{submodule_name}.{name} instead.",
        DeprecationWarning,
        # Attribute the warning to the caller's attribute access, not this hook.
        stacklevel=2,
    )
    module = importlib.import_module(f"vllm.utils.{submodule_name}")
    return getattr(module, name)


def __dir__() -> list[str]:
    """Return this module's attribute names plus its deprecated aliases.

    The deprecated names are only resolved lazily by ``__getattr__``; listing
    them here keeps them visible to ``dir()`` and tab-completion. A set union
    guarantees each name appears exactly once.
    """
    return sorted(set(globals()) | set(_DEPRECATED_MAPPINGS))


if TYPE_CHECKING:
    from argparse import Namespace

    from vllm.config import ModelConfig, VllmConfig
else:
    # At runtime these names are only used in annotations; binding them to
    # ``object`` lets those annotations evaluate without importing
    # ``vllm.config``.
    Namespace = object

    ModelConfig = object
    VllmConfig = object

# Logger shared by the helpers in this module.
logger = init_logger(__name__)

# Default per-step token budgets. The general default balances inter-token
# latency (ITL) against time-to-first-token (TTFT); it is not tuned for
# throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

# --- Forcing the attention backend -----------------------------------------
# Name of the environment variable that can be set to override the attention
# backend auto-selection performed by the Attention wrapper.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Recognized values for the STR_BACKEND_ENV_VAR environment variable.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

# --- ANSI terminal color codes ---------------------------------------------
CYAN = "\033[1;36m"
RESET = "\033[0;0m"

# --- Generic type variables ------------------------------------------------
T = TypeVar("T")
U = TypeVar("U")

class Device(enum.Enum):
    """Device kinds; values are assigned by ``enum.auto()`` (GPU=1, CPU=2)."""

    GPU = enum.auto()
    CPU = enum.auto()


class LayerBlockType(enum.Enum):
    """Kinds of layer block; each member's value is its own lower-case name."""

    attention = "attention"
    mamba = "mamba"


class Counter:
    """Integer counter that hands out consecutive values via ``next()``."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        """Return the current value, then advance the counter by one."""
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        """Set the counter back to 0.

        Note: this always resets to 0, not to the ``start`` value passed to
        ``__init__``.
        """
        self.counter = 0

def random_uuid() -> str:
    """Return a random UUID4 as a 32-character lowercase hex string."""
    # ``UUID.hex`` is already a ``str``; no conversion is needed.
    return uuid.uuid4().hex

def update_environment_variables(envs: dict[str, str]):
    """Copy every entry of ``envs`` into ``os.environ``.

    A warning is logged for each variable that is already set to a different
    value before it is overwritten.

    Note: the ``envs`` parameter shadows the module-level ``vllm.envs`` import
    inside this function.
    """
    for name, new_value in envs.items():
        old_value = os.environ.get(name)
        # os.environ only stores str values, so None means "not set".
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                name,
                old_value,
                new_value,
            )
        os.environ[name] = new_value

@cache
def is_pin_memory_available() -> bool:
    """Return whether the current platform supports pinned host memory.

    The answer comes from ``current_platform`` and is memoized by
    ``functools.cache``, so the platform is queried at most once per process.
    """
    # Deferred import (presumably to avoid an import cycle / import-time cost;
    # TODO confirm).
    from vllm.platforms import current_platform

    return current_platform.is_pin_memory_available()


@cache
def is_uva_available() -> bool:
    """Report whether Unified Virtual Addressing (UVA) can be used.

    Pinned host memory is currently the only prerequisite that is checked.
    """
    # TODO: Add more requirements for UVA if needed.
    pinned_memory_ok = is_pin_memory_available()
    return pinned_memory_ok


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
    """Lazily initialize the Hugging Face dynamic-module cache.

    Delegates to ``transformers.dynamic_module_utils.init_hf_modules``. The
    import is deferred to call time, so ``transformers`` is only loaded when
    this function is actually used.
    """
    from transformers.dynamic_module_utils import init_hf_modules

    init_hf_modules()


def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
    """Set up function tracing for the current thread.

    This does nothing unless ``envs.VLLM_TRACE_FUNCTION`` (the
    ``VLLM_TRACE_FUNCTION`` environment variable) is enabled. When enabled,
    the trace log is written under
    ``<tempdir>/<user>/vllm/vllm-instance-<instance_id>/``. Its file name
    encodes the process id, thread id and start time, with spaces replaced
    by underscores.

    Args:
        vllm_config: Config whose ``instance_id`` namespaces the log directory.
    """
    if not envs.VLLM_TRACE_FUNCTION:
        return

    tmp_dir = tempfile.gettempdir()
    # add username to tmp_dir to avoid permission issues
    tmp_dir = os.path.join(tmp_dir, getpass.getuser())
    filename = (
        f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
        f"_thread_{threading.get_ident()}_"
        f"at_{datetime.datetime.now()}.log"
    ).replace(" ", "_")
    log_path = os.path.join(
        tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename
    )
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    enable_trace_function_call(log_path)


def weak_bind(
    bound_method: Callable[..., Any],
) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected.

    The returned callable forwards its arguments to the underlying function,
    passing the still-alive instance as ``self``. It always returns ``None``;
    the method's own return value is discarded.
    """
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        # Compare with None explicitly: a live instance can still be falsy
        # (e.g. it defines __len__ returning 0), and it must not be skipped.
        if (inst := ref()) is not None:
            unbound(inst, *args, **kwargs)

    return weak_bound


class AtomicCounter:
    """A thread-safe counter whose updates happen while holding a lock."""

    def __init__(self, initial=0):
        """Create the counter with ``initial`` as its starting value."""
        self._value = initial
        self._lock = threading.Lock()

    def inc(self, num=1):
        """Add ``num`` under the lock and return the resulting value."""
        with self._lock:
            updated = self._value + num
            self._value = updated
        return updated

    def dec(self, num=1):
        """Subtract ``num`` under the lock and return the resulting value."""
        with self._lock:
            updated = self._value - num
            self._value = updated
        return updated

    @property
    def value(self):
        """Current value (read without taking the lock)."""
        return self._value


def kill_process_tree(pid: int):
    """
    Kill the process ``pid`` and all of its descendants by sending SIGKILL.

    Descendants are signalled first, then the process itself. A process that
    has already exited is skipped silently. This includes the case where
    ``pid`` exits while its children are being enumerated.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
        # Get all children recursively. The parent can exit between the two
        # calls, in which case psutil raises NoSuchProcess here as well.
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
    """Raise the soft open-file-descriptor limit to ``target_soft_limit``.

    The soft limit is only ever raised, never lowered. An unlimited
    (``RLIM_INFINITY``) soft limit is left alone, and the hard limit is kept
    unchanged. On Windows this is a no-op. If the OS rejects the new limit, a
    warning is logged instead of raising.
    """
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    # ``resource`` is POSIX-only, so it is imported after the Windows check.
    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # RLIM_INFINITY may be represented as -1, which would otherwise compare
    # as "smaller" and cause an unlimited soft limit to be lowered.
    if current_soft != resource.RLIM_INFINITY and current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            # ValueError: e.g. soft limit above the hard limit;
            # OSError: the underlying setrlimit() system call failed.
            logger.warning(
                "Found ulimit of %s and failed to automatically increase "
                "with error %s. This can cause fd limit errors like "
                "`OSError: [Errno 24] Too many open files`. Consider "
                "increasing with ulimit -n",
                current_soft,
                e,
            )


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    """Format the exception currently being handled as one string.

    Meant to be called from inside an ``except`` block; it formats whatever
    ``sys.exc_info()`` reports at the time of the call.
    """
    exc_info = sys.exc_info()
    return "".join(traceback.format_exception(*exc_info))


def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.

    If ``VLLM_WORKER_MULTIPROC_METHOD`` is already ``"spawn"``, this does
    nothing. Otherwise spawn is forced when either:

    * we are running inside a Ray actor, or
    * CUDA (or, failing that, XPU) has already been initialized.

    Forcing means logging a warning that lists the reasons and setting
    ``os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"``. Inside a Ray
    actor, ``RAY_ADDRESS`` is also exported so that spawned children can
    connect to the same cluster.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/usage/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reasons: %s",
            "; ".join(reasons),
        )
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).

    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
    determine the multiprocessing method (default is fork). However, under
    certain conditions, we may enforce spawn and override the value of
    VLLM_WORKER_MULTIPROC_METHOD.

    Returns:
        The ``multiprocessing`` context for ``envs.VLLM_WORKER_MULTIPROC_METHOD``,
        read after ``_maybe_force_spawn`` has had a chance to override it.
    """
    # May rewrite VLLM_WORKER_MULTIPROC_METHOD, so it must run before the read.
    _maybe_force_spawn()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)


def bind_kv_cache(
    ctx: dict[str, Any],
    kv_cache: list[list[torch.Tensor]],  # [virtual_engine][layer_index]
    shared_kv_cache_layers: dict[str, str] | None = None,
) -> None:
    """Bind per-layer KV-cache tensors onto the attention modules in ``ctx``.

    Args:
        ctx: Maps layer names to attention modules (mutated in place).
        kv_cache: ``kv_cache[ve][i]`` is the tensor for virtual engine ``ve``
            and the ``i``-th distinct layer index that needs a KV cache.
        shared_kv_cache_layers: Optional mapping from a layer name to an
            earlier layer whose ``kv_cache`` it reuses.
    """
    # Bind the kv_cache tensor to Attention modules, similar to
    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
    # Special things handled here:
    # 1. Some models have non-attention layers, e.g., Jamba
    # 2. Pipeline parallelism, each rank only has a subset of layers
    # 3. Encoder attention has no kv cache
    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
    #    tensor
    # 5. Some models have attention layers that share kv cache with previous
    #    layers, this is specified through shared_kv_cache_layers
    if shared_kv_cache_layers is None:
        shared_kv_cache_layers = {}

    from vllm.attention import AttentionType
    from vllm.model_executor.models.utils import extract_layer_index

    layer_need_kv_cache = [
        layer_name
        for layer_name in ctx
        if (
            hasattr(ctx[layer_name], "attn_type")
            and ctx[layer_name].attn_type
            in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)
        )
        and ctx[layer_name].kv_sharing_target_layer_name is None
    ]
    # Map each distinct layer index (sorted ascending) to its position in the
    # compacted per-engine kv_cache list. A dict lookup replaces the O(n)
    # list.index() call that was previously done once per layer.
    layer_index_to_kv_idx = {
        layer_index: kv_idx
        for kv_idx, layer_index in enumerate(
            sorted({extract_layer_index(name) for name in layer_need_kv_cache})
        )
    }
    for layer_name in layer_need_kv_cache:
        kv_cache_idx = layer_index_to_kv_idx[extract_layer_index(layer_name)]
        forward_ctx = ctx[layer_name]
        assert len(forward_ctx.kv_cache) == len(kv_cache)
        for ve, ve_kv_cache in enumerate(kv_cache):
            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]

    # Layers that reuse an earlier layer's cache alias the same list object.
    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        assert extract_layer_index(target_layer_name) < extract_layer_index(
            layer_name
        ), "v0 doesn't support interleaving kv sharing"
        ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache


408
409
def run_method(
    obj: Any,
410
    method: str | bytes | Callable,
411
412
413
    args: tuple[Any],
    kwargs: dict[str, Any],
) -> Any:
414
415
416
417
418
419
420
421
422
423
424
425
426
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
427
428
429
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
430
431
432
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453


def import_pynvml():
    """
    Return vLLM's vendored copy of the official ``pynvml`` module.

    Historical comments:

    libnvml.so is the library behind nvidia-smi, and
    pynvml is a Python wrapper around it. We use it to get GPU
    status without initializing CUDA context in the current process.
    Historically, there are two packages that provide pynvml:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
        wrapper. It is a dependency of vLLM, and is installed when users
        install vLLM. It provides a Python module named `pynvml`.
    - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
        Prior to version 12.0, it also provides a Python module `pynvml`,
        and therefore conflicts with the official one. What's worse,
        the module is a Python package, and has higher priority than
        the official one which is a standalone Python file.
        This causes errors when both of them are installed.
        Starting from version 12.0, it migrates to a new module
        named `pynvml_utils` to avoid the conflict.
    It is so confusing that many packages in the community use the
    unofficial one by mistake, and we have to handle this case.
    For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
    one, and it will cause errors, see the issue
    https://github.com/vllm-project/vllm/issues/12847 for example.
    After all the troubles, we decided to copy the official `pynvml`
    module to our codebase, and use it directly.
    """
    import vllm.third_party.pynvml as pynvml

    return pynvml


def warn_for_unimplemented_methods(cls: "type[T]") -> "type[T]":
    """
    A replacement for `abc.ABC`.
    When we use `abc.ABC`, subclasses will fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class. After ``__init__`` runs, a debug message is logged that lists
    the public methods whose source still mentions ``NotImplementedError``.

    This is a source-text heuristic:

    * only bound methods (objects with ``__func__``) are inspected;
    * data attributes and methods whose source cannot be retrieved are skipped.

    Returns:
        ``cls`` itself, with ``__init__`` wrapped in place.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        unimplemented_methods = []
        for attr_name in dir(self):
            # bypass inner method
            if attr_name.startswith("_"):
                continue

            try:
                attr = getattr(self, attr_name)
            except AttributeError:
                continue
            # Only bound methods expose ``__func__``. Skip data attributes and
            # other callables. This also avoids reusing a stale/unbound
            # ``attr_func`` from a previous iteration.
            attr_func = getattr(attr, "__func__", None)
            if attr_func is None:
                continue
            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source not available (e.g. dynamically created functions).
                continue
            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)
        if unimplemented_methods:
            method_names = ",".join(unimplemented_methods)
            msg = f"Methods {method_names} not implemented in {self}"
            logger.debug(msg)

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    type.__setattr__(cls, "__init__", wrapped_init)
    return cls


def check_use_alibi(model_config: "ModelConfig") -> bool:
    """Return whether the model described by ``model_config`` uses ALiBi.

    Only relevant for models using ALiBi (e.g. MPT). Model families advertise
    ALiBi in different places of their HF config, and each known location is
    checked in turn.
    """
    cfg = model_config.hf_text_config

    # Falcon
    if getattr(cfg, "alibi", False):
        return True

    # Bloom. HF configs may set ``architectures = None`` explicitly, so fall
    # back to an empty list instead of testing membership in None.
    architectures = getattr(model_config.hf_config, "architectures", None) or []
    if "BloomForCausalLM" in architectures:
        return True

    # codellm_1b_alibi
    if getattr(cfg, "position_encoding_type", "") == "alibi":
        return True

    # MPT: ``attn_config`` may be a plain dict or a config object.
    if hasattr(cfg, "attn_config"):
        attn_config = cfg.attn_config
        if isinstance(attn_config, dict):
            return bool(attn_config.get("alibi", False))
        return bool(getattr(attn_config, "alibi", False))

    return False


535
def length_from_prompt_token_ids_or_embeds(
536
537
    prompt_token_ids: list[int] | None,
    prompt_embeds: torch.Tensor | None,
538
) -> int:
539
    """Calculate the request length (in number of tokens) give either
540
541
    prompt_token_ids or prompt_embeds.
    """
542
543
    prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
    prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
544
545
546

    if prompt_token_len is None:
        if prompt_embeds_len is None:
547
            raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
548
549
        return prompt_embeds_len
    else:
550
        if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
551
552
553
            raise ValueError(
                "Prompt token ids and prompt embeds had different lengths"
                f" prompt_token_ids={prompt_token_len}"
554
555
                f" prompt_embeds={prompt_embeds_len}"
            )
556
        return prompt_token_len