utils.py 9.35 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
import enum
2
import os
3
import socket
4
import subprocess
Zhuohan Li's avatar
Zhuohan Li committed
5
import uuid
6
from platform import uname
7
8
from typing import List, Tuple, Union
from packaging.version import parse, Version
Zhuohan Li's avatar
Zhuohan Li committed
9

10
import psutil
Zhuohan Li's avatar
Zhuohan Li committed
11
import torch
12
13
14
15
16
17
18
19
20
21
import asyncio
from functools import partial
from typing import (
    Awaitable,
    Callable,
    TypeVar,
)
from collections import OrderedDict
from typing import Any, Hashable, Optional

22
23
from vllm.logger import init_logger

24
T = TypeVar("T")
25
26
27
28
29
30
31
32
logger = init_logger(__name__)

# Lookup table from dtype name strings to torch dtypes.
# Note that "fp8_e5m2" is mapped to torch.uint8 rather than to a
# floating-point torch dtype.
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8_e5m2": torch.uint8,
}
Zhuohan Li's avatar
Zhuohan Li committed
33

Woosuk Kwon's avatar
Woosuk Kwon committed
34
35
36
37
38
39
40
41
42
43
44

@enum.unique
class Device(enum.Enum):
    """Device kinds distinguished by this module: GPU and CPU."""

    GPU = enum.auto()
    CPU = enum.auto()


class Counter:
    """An integer counter that is advanced with the built-in ``next()``.

    Each ``next(counter)`` call returns the current value and then
    increments it by one.
    """

    def __init__(self, start: int = 0) -> None:
        """Creates a counter whose first returned value is `start`."""
        self.counter = start

    def __next__(self) -> int:
        """Returns the current value, then increments the counter."""
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        """Sets the counter back to 0 (not to the initial `start`)."""
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
52

53

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class LRUCache:
    """A capacity-bounded mapping that evicts least recently used entries.

    Entries are kept in an ``OrderedDict`` ordered from least to most
    recently used. ``get``, ``put`` and ``touch`` move a key to the
    most-recent end; whenever the number of entries exceeds ``capacity``
    the oldest entries are removed. Subclasses may override ``_on_remove``
    to react to an entry leaving the cache.
    """

    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> Any:
        # Delegates to `get`, so a missing key yields None instead of
        # raising KeyError.
        return self.get(key)

    def __setitem__(self, key: Hashable, value: Any) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        # Delegates to `pop`, so deleting a missing key is a no-op.
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        """Marks `key` as most recently used.

        Raises:
            KeyError: If `key` is not in the cache.
        """
        self.cache.move_to_end(key)

    def get(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
        """Returns the value for `key` and marks it most recently used.

        Returns `default_value` (without touching the order) when the key
        is absent.
        """
        if key not in self.cache:
            return default_value
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: Hashable, value: Any) -> None:
        """Stores `value` under `key` as the most recently used entry.

        Evicts the oldest entries if the cache grows beyond `capacity`.
        Overwriting an existing key does not trigger `_on_remove` for the
        previous value.
        """
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def _on_remove(self, key: Hashable, value: Any) -> None:
        """Hook called after an entry is removed; does nothing by default."""
        pass

    def remove_oldest(self) -> None:
        """Removes the least recently used entry, if there is one."""
        if not self.cache:
            return
        key, value = self.cache.popitem(last=False)
        self._on_remove(key, value)

    def _remove_old_if_needed(self) -> None:
        """Evicts oldest entries until the size is within `capacity`."""
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
        """Removes `key` and returns its value.

        Returns `default_value` when the key is absent; `_on_remove` is
        only invoked when an entry was actually removed.
        """
        if key not in self.cache:
            return default_value
        value = self.cache.pop(key)
        self._on_remove(key, value)
        return value

    def clear(self) -> None:
        """Removes every entry, oldest first, invoking `_on_remove` each time."""
        while self.cache:
            self.remove_oldest()


117
118
119
120
def is_hip() -> bool:
    """Returns True when `torch.version.hip` is set (i.e. is not None)."""
    hip_version = torch.version.hip
    return hip_version is not None


121
122
123
124
125
126
127
128
def is_neuron() -> bool:
    """Returns True when the `transformers_neuronx` package can be imported."""
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True


129
130
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    Args:
        gpu: Value passed through to the `cuda_utils` device-attribute
            query (presumably the device index).

    Raises:
        AssertionError: If the queried value is not positive.
    """
    # NOTE: This import statement should be executed lazily since
    # the Neuron-X backend does not have the `cuda_utils` module.
    from vllm._C import cuda_utils

    max_shared_mem = (
        cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
    # A value of 0 would make MAX_SEQ_LEN negative and cause
    # test_attention.py to fail.
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)


142
def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total
Zhuohan Li's avatar
Zhuohan Li committed
145
146
147
148


def random_uuid() -> str:
    """Returns a freshly generated UUID4 as a 32-character hex string."""
    new_id = uuid.uuid4()
    return new_id.hex
149

150

151
152
153
def in_wsl() -> bool:
    """Returns True when any `uname()` field mentions "microsoft".

    Reference: https://github.com/microsoft/WSL/issues/4071
    """
    uname_text = " ".join(uname())
    return "microsoft" in uname_text.lower()
154
155


156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Wraps a blocking function so each call runs in an executor.

    The returned wrapper submits the call to the event loop's default
    executor and hands back the resulting future, so awaiting it does not
    block the asyncio event loop. Since `func` no longer runs on the event
    loop itself, its code needs to be thread safe.
    """

    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        bound_call = partial(func, *args, **kwargs)
        # Passing None as the executor selects the loop's default executor.
        return asyncio.get_event_loop().run_in_executor(None, bound_call)

    return _async_wrapper


172
def get_ip() -> str:
    """Returns the local address the OS picks for reaching dns.google.

    A UDP socket is connected to ``dns.google:80`` and its local socket
    name is read back. IPv4 is tried first; on `OSError` the same is
    attempted over IPv6. Sockets are closed before returning.

    Raises:
        OSError: If the IPv6 fallback fails as well.
    """
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("dns.google", 80))  # Doesn't need to be reachable
            return s.getsockname()[0]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
            s.connect(("dns.google", 80))
            return s.getsockname()[0]
183
184


185
186
187
188
def get_distributed_init_method(ip: str, port: int) -> str:
    """Returns a ``tcp://host:port`` init-method URL for `ip` and `port`.

    IPv6 addresses are wrapped in square brackets so that the colons in
    the address cannot be confused with the port separator. An address
    that is already bracketed is left unchanged.
    """
    if ":" in ip and not ip.startswith("["):
        ip = f"[{ip}]"
    return f"tcp://{ip}:{port}"


189
def get_open_port() -> int:
    """Returns a port number that was free at the time of the call.

    A TCP socket is bound to port 0 so the OS chooses a free port; IPv4 is
    tried first, with IPv6 as the fallback on `OSError`. The socket is
    closed before returning, so the port is not reserved for the caller.
    """
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
200
201
202
203


def set_cuda_visible_devices(device_ids: List[int]) -> None:
    """Sets CUDA_VISIBLE_DEVICES to the comma-separated `device_ids`."""
    visible = ",".join(str(device_id) for device_id in device_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = visible
204
205


206
def get_nvcc_cuda_version() -> Optional[Version]:
    """Returns the CUDA release reported by ``nvcc -V``.

    The toolkit location is taken from the CUDA_HOME environment variable.
    When that is unset or empty, ``/usr/local/cuda`` is used if it contains
    ``bin/nvcc``; otherwise a warning is logged and None is returned.

    Raises:
        OSError: If the nvcc binary cannot be executed.
        subprocess.CalledProcessError: If nvcc exits with a non-zero status.
        ValueError: If the output does not contain the word "release".
    """
    cuda_home = os.environ.get('CUDA_HOME')
    if not cuda_home:
        cuda_home = '/usr/local/cuda'
        if os.path.isfile(cuda_home + '/bin/nvcc'):
            logger.info(
                f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
            )
        else:
            logger.warning(
                f'Not found nvcc in {cuda_home}. Skip cuda version check!')
            return None
    nvcc_path = cuda_home + "/bin/nvcc"
    nvcc_output = subprocess.check_output([nvcc_path, "-V"],
                                          universal_newlines=True)
    # The version is the whitespace-separated token following "release",
    # with anything after the first comma stripped.
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
    return nvcc_cuda_version


def _generate_random_fp8_e5m2(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fills `tensor` in place with random fp8_e5m2-encoded values.

    A float16 tensor of the same shape is sampled uniformly between `low`
    and `high` and then converted into `tensor` with
    `cache_ops.convert_fp8_e5m2`.
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # Inf or NaN may occur if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm._C import cache_ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
    del tensor_tmp


def _make_random_kv_cache(
    shape: Tuple[int, ...],
    torch_dtype: torch.dtype,
    cache_dtype: Optional[Union[str, torch.dtype]],
    scale: float,
    device: Optional[str],
    kind: str,
) -> torch.Tensor:
    """Allocates one cache tensor and fills it with values in -scale..scale."""
    cache = torch.empty(size=shape, dtype=torch_dtype, device=device)
    if cache_dtype == 'fp8_e5m2':
        _generate_random_fp8_e5m2(cache, -scale, scale)
    elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
        cache.uniform_(-scale, scale)
    else:
        raise ValueError(
            f"Does not support {kind} cache of type {cache_dtype}")
    return cache


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: Optional[int] = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Creates randomly initialized key and value caches, one per layer.

    Args:
        num_blocks: Size of the first dimension of every cache tensor.
        block_size: Block dimension of every cache tensor.
        num_layers: Number of key tensors and of value tensors to create.
        num_heads: Size of the second dimension of every cache tensor.
        head_size: Head dimension; values are sampled between
            ``-head_size**-0.5`` and ``head_size**-0.5``.
        cache_dtype: A torch dtype, or one of the strings "auto", "half",
            "bfloat16", "float", "fp8_e5m2". "auto" uses `model_dtype`.
        model_dtype: Dtype (or dtype name) used when `cache_dtype` is "auto".
        seed: Seed for the torch RNGs; None leaves the RNG state untouched.
        device: Device on which the tensors are allocated.

    Returns:
        A tuple ``(key_caches, value_caches)``. Key tensors have shape
        ``(num_blocks, num_heads, head_size // x, block_size, x)`` where
        ``x`` is 16 divided by the element size in bytes; value tensors have
        shape ``(num_blocks, num_heads, head_size, block_size)``.

    Raises:
        ValueError: If the dtype arguments are invalid or the resolved dtype
            cannot be randomly initialized.
    """
    if seed is not None:
        torch.random.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8_e5m2":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    scale = head_size**-0.5
    # x is the number of elements that fit in 16 bytes for this dtype.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)

    # All key caches are generated before any value cache so that the
    # sequence of random draws for a given seed stays the same.
    key_caches = [
        _make_random_kv_cache(key_cache_shape, torch_dtype, cache_dtype,
                              scale, device, "key")
        for _ in range(num_layers)
    ]
    value_caches = [
        _make_random_kv_cache(value_cache_shape, torch_dtype, cache_dtype,
                              scale, device, "value")
        for _ in range(num_layers)
    ]
    return key_caches, value_caches