__init__.py 5.56 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
import enum
5
import inspect
6
import threading
Zhuohan Li's avatar
Zhuohan Li committed
7
import uuid
8
import warnings
9
10
from functools import wraps
from typing import Any, TypeVar
Zhuohan Li's avatar
Zhuohan Li committed
11
12

import torch
13

14
from vllm.logger import init_logger
15

16
17
18
# Maps a deprecated ``vllm.utils.<name>`` attribute to the ``vllm.utils``
# submodule that now provides it. The module-level ``__getattr__`` consults
# this table to resolve such names lazily.
_DEPRECATED_MAPPINGS = dict(
    cprofile="profiling",
    cprofile_context="profiling",
    # Used by lm-eval
    get_open_port="network_utils",
)
22
23
24


def __getattr__(name: str) -> Any:  # noqa: D401 - short deprecation docstring
    """Module-level getattr to handle deprecated utilities.

    Names listed in ``_DEPRECATED_MAPPINGS`` are resolved lazily from the
    ``vllm.utils`` submodule they moved to, and a ``DeprecationWarning``
    pointing at the new location is emitted on each access.

    Args:
        name: Attribute name that normal module lookup failed to find.

    Returns:
        The object of the same name from the new submodule.

    Raises:
        AttributeError: If ``name`` is not a known deprecated alias.
    """
    # Local import keeps this rarely-used path from adding a module-level
    # dependency; the import itself is cached by the interpreter.
    import importlib

    submodule_name = _DEPRECATED_MAPPINGS.get(name)
    if submodule_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(
        f"vllm.utils.{name} is deprecated and will be removed in a future version. "
        f"Use vllm.utils.{submodule_name}.{name} instead.",
        DeprecationWarning,
        # Attribute the warning to the caller that accessed the old name.
        stacklevel=2,
    )
    # importlib.import_module returns the submodule itself, unlike
    # __import__, which needs the ``fromlist`` trick to do so.
    module = importlib.import_module(f"vllm.utils.{submodule_name}")
    return getattr(module, name)


def __dir__() -> list[str]:
    """List the module's attributes plus the deprecated aliases.

    Including the lazily-resolved names from ``_DEPRECATED_MAPPINGS`` makes
    them visible to ``dir()`` and to interactive tab-completion.
    """
    names = [*globals(), *_DEPRECATED_MAPPINGS]
    names.sort()
    return names
42
43


44
45
# Logger for this module, created through vLLM's ``init_logger`` helper and
# named after this module.
logger = init_logger(__name__)

46
47
48
49
50
51
# Default ``max_num_batched_tokens`` values; the names indicate the generic,
# pooling-model and multimodal-model cases (the selection between them is
# made by callers outside this module).
#
# The generic default balances inter-token latency (ITL) against
# time-to-first-token (TTFT); it is deliberately not tuned for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2_048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32_768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5_120

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# --- Attention-backend override -------------------------------------------
#
# Name of the environment variable that can be set to force a particular
# attention backend instead of leaving the choice to the Attention wrapper's
# automatic selection.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Values recognised for the variable named by STR_BACKEND_ENV_VAR, one per
# backend, plus a marker string for an invalid choice.
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

# Generic type variable shared by the helpers in this module.
T = TypeVar("T")

70

Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74
75
class Device(enum.Enum):
    """Kind of device: GPU or CPU (values assigned by ``enum.auto``)."""

    GPU = enum.auto()
    CPU = enum.auto()


76
77
78
79
80
class LayerBlockType(enum.Enum):
    """Kind of layer block; each member's value equals its name."""

    attention = "attention"
    mamba = "mamba"


Woosuk Kwon's avatar
Woosuk Kwon committed
81
82
83
84
class Counter:
    """Simple integer counter advanced with ``next()``.

    Each ``next(counter)`` returns the current value and then increments it
    by one. ``reset()`` sets the value back to ``0`` (not to ``start``).
    """

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        # Hand out the current value, then advance.
        current, self.counter = self.counter, self.counter + 1
        return current

    def reset(self) -> None:
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
92

93

Cyrus Leung's avatar
Cyrus Leung committed
94
95
class AtomicCounter:
    """Counter whose increments and decrements are serialized by a lock."""

    def __init__(self, initial=0):
        """Create the counter, starting from ``initial``."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add ``num`` while holding the lock and return the updated value."""
        with self._lock:
            self._value += num
            updated = self._value
        return updated

    def dec(self, num=1):
        """Subtract ``num`` while holding the lock and return the updated value."""
        with self._lock:
            self._value -= num
            updated = self._value
        return updated

    @property
    def value(self):
        # Plain read of the stored value; the lock is not taken here.
        return self._value
117
118


Cyrus Leung's avatar
Cyrus Leung committed
119
120
def random_uuid() -> str:
    """Return a fresh random UUID (``uuid.uuid4``) as a 32-char hex string."""
    # ``UUID.hex`` is already a ``str``; wrapping it in ``str()`` was redundant.
    return uuid.uuid4().hex
121
122


123
def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
    """Class decorator used as a lightweight replacement for `abc.ABC`.

    With `abc.ABC`, a subclass that does not implement every abstract method
    cannot be instantiated. Here the base class only needs its placeholder
    methods to `raise NotImplementedError`. After each instance finishes
    `__init__`, every public bound method whose source still mentions
    `NotImplementedError` is reported via `logger.debug`; instantiation is
    never blocked.

    Args:
        cls: The class to decorate. Its `__init__` is replaced in place.

    Returns:
        The same class object, now carrying the wrapped `__init__`.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object) -> None:
        unimplemented_methods = []
        for attr_name in dir(self):
            # Skip private and dunder attributes.
            if attr_name.startswith("_"):
                continue

            try:
                attr = getattr(self, attr_name)
            except AttributeError:
                continue

            # Only bound (instance/class) methods expose ``__func__``. Skip
            # every other attribute explicitly so ``attr_func`` always
            # belongs to *this* attribute, never to an earlier iteration.
            if not callable(attr):
                continue
            attr_func = getattr(attr, "__func__", None)
            if attr_func is None:
                continue

            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source is unavailable (e.g. built-in, generated at runtime,
                # or defined interactively); nothing to inspect.
                continue
            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)

        if unimplemented_methods:
            # Lazy %-formatting: str(self) is only computed if DEBUG is on.
            logger.debug(
                "Methods %s not implemented in %s",
                ",".join(unimplemented_methods),
                self,
            )

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    # NOTE(review): assigned through ``type.__setattr__`` rather than
    # ``cls.__init__ = ...``, presumably to bypass a metaclass that overrides
    # ``__setattr__`` — confirm.
    type.__setattr__(cls, "__init__", wrapped_init)
    return cls
164
165


166
def length_from_prompt_token_ids_or_embeds(
167
168
    prompt_token_ids: list[int] | None,
    prompt_embeds: torch.Tensor | None,
169
) -> int:
170
    """Calculate the request length (in number of tokens) give either
171
172
    prompt_token_ids or prompt_embeds.
    """
173
174
    prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
    prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
175
176
177

    if prompt_token_len is None:
        if prompt_embeds_len is None:
178
            raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
179
180
        return prompt_embeds_len
    else:
181
        if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
182
183
184
            raise ValueError(
                "Prompt token ids and prompt embeds had different lengths"
                f" prompt_token_ids={prompt_token_len}"
185
186
                f" prompt_embeds={prompt_embeds_len}"
            )
187
        return prompt_token_len