import asyncio
import enum
import gc
import os
import socket
import subprocess
import uuid
import warnings
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
                    Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
                    Union)

import psutil
import torch
from packaging.version import Version, parse

from vllm.logger import init_logger

T = TypeVar("T")

logger = init_logger(__name__)

# String dtype names accepted by the KV-cache helpers in this module, mapped
# to the torch dtype used for storage. "fp8" data is stored as uint8.
STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
    fp8=torch.uint8,
)

class Device(enum.Enum):
    """Enumerates the device kinds known to this module."""
    # Explicit values match what enum.auto() assigns (1, 2, ...).
    GPU = 1
    CPU = 2


class Counter:
    """Integer counter driven by ``next()``.

    ``next(counter)`` returns the current value and then increments it by
    one, so ``Counter(start)`` produces ``start, start + 1, ...``.
    """

    def __init__(self, start: int = 0) -> None:
        # Remember the initial value so reset() can return to it.
        self.start = start
        self.counter = start

    def __iter__(self) -> "Counter":
        # Makes the counter usable wherever an iterator is expected.
        return self

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        """Rewind the counter to the ``start`` value it was created with."""
        self.counter = self.start

class LRUCache(Generic[T]):
    """Mapping with a fixed capacity that evicts least-recently-used entries.

    Entries live in an ``OrderedDict`` ordered from least- to most-recently
    used. ``get``, ``touch`` and ``put`` move a key to the most-recently-used
    end. ``put`` then evicts from the other end while the cache holds more
    than ``capacity`` entries. Every eviction, and every ``pop`` of a present
    key, calls the ``_on_remove`` hook, which subclasses may override.

    Note that ``cache[key]`` does not raise for a missing key; like ``get``
    it returns ``None``.
    """

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> Optional[T]:
        # Unlike dict, a missing key yields None rather than KeyError.
        return self.get(key)

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        """Mark ``key`` as most recently used (``KeyError`` if absent)."""
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Return the value for ``key`` and mark it most recently used.

        Returns ``default_value`` (without touching the cache) if ``key`` is
        absent.
        """
        if key in self.cache:
            value: Optional[T] = self.cache[key]
            self.cache.move_to_end(key)
        else:
            value = default_value
        return value

    def put(self, key: Hashable, value: T) -> None:
        """Insert or update ``key`` as most recently used, then evict."""
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def _on_remove(self, key: Hashable, value: Optional[T]):
        """Hook called for each entry removed by eviction or ``pop``."""
        pass

    def remove_oldest(self):
        """Evict the least-recently-used entry, if there is one."""
        if not self.cache:
            return
        key, value = self.cache.popitem(last=False)
        self._on_remove(key, value)

    def _remove_old_if_needed(self) -> None:
        # Also stop once the cache is empty. With a negative capacity the
        # size check alone never becomes false, and remove_oldest() is a
        # no-op on an empty cache, so the loop would spin forever.
        while self.cache and len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Remove ``key`` and return its value (or ``default_value``).

        ``_on_remove`` is only called when ``key`` was actually present.
        """
        run_on_remove = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        if run_on_remove:
            self._on_remove(key, value)
        return value

    def clear(self):
        """Remove every entry, oldest first, calling ``_on_remove`` for each."""
        while len(self.cache) > 0:
            self.remove_oldest()
        self.cache.clear()


def is_hip() -> bool:
    """Return True when ``torch.version.hip`` is set (a HIP build of PyTorch)."""
    hip_version = torch.version.hip
    return hip_version is not None


@lru_cache(maxsize=None)
def is_cpu() -> bool:
    """Return True if the installed ``vllm`` version string contains "cpu".

    Returns False when the ``vllm`` distribution is not installed. The
    result is cached for the lifetime of the process.
    """
    from importlib.metadata import PackageNotFoundError, version
    try:
        installed_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "cpu" in installed_version


@lru_cache(maxsize=None)
def is_neuron() -> bool:
    """Return True if ``transformers_neuronx`` can be imported (cached)."""
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    Args:
        gpu: Index of the device to query. Results are cached per value.

    Raises:
        AssertionError: if the reported value is not positive.
    """
    # Imported lazily: the Neuron-X backend does not provide `cuda_utils`.
    from vllm._C import cuda_utils

    query_attribute = (
        cuda_utils.get_max_shared_memory_per_block_device_attribute)
    max_shared_mem = query_attribute(gpu)
    # A value of 0 would make MAX_SEQ_LEN negative and make
    # test_attention.py fail, so reject it here.
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    memory_stats = psutil.virtual_memory()
    return memory_stats.total


def random_uuid() -> str:
    """Return a fresh random UUID4 as a 32-character lowercase hex string."""
    # UUID.hex is already a str, so no extra conversion is needed.
    return uuid.uuid4().hex

@lru_cache(maxsize=None)
def in_wsl() -> bool:
    """Return True when the uname fields mention "microsoft" (i.e. WSL).

    The result is cached for the lifetime of the process.
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    system_description = " ".join(uname()).lower()
    return system_description.find("microsoft") != -1


def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Take a blocking function, and run it in an executor thread.

    The returned wrapper submits ``func(*args, **kwargs)`` to the event
    loop's default executor and returns the resulting future. Awaiting that
    future does not block the asyncio event loop. ``func`` runs on a worker
    thread, so the code in it needs to be thread safe.

    The wrapper carries ``func``'s ``__name__``, ``__doc__`` and other
    metadata (via ``functools.wraps``), so logs and introspection show the
    wrapped function.
    """

    @wraps(func)
    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper


def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple ``(i, item)``, where ``i`` is the index
    of the iterator that yielded ``item``. Items from any one iterator keep
    their relative order.

    One producer task per iterator is started immediately, so this must be
    called while an event loop is running. If an iterator raises an
    ``Exception``, it is re-raised from the merged iterator. Whenever the
    merged iterator stops (normally, on error, or via ``aclose()`` /
    cancellation), any producer tasks still running are cancelled.
    """
    # Each producer enqueues this sentinel exactly once, after all of its
    # items (or its exception). The consumer waits for one sentinel per
    # producer. It therefore never blocks on an empty queue after the last
    # producer has finished without enqueueing anything.
    producer_done = object()
    queue: asyncio.Queue[Union[Tuple[int, T], Exception, object]] = (
        asyncio.Queue())

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finally:
            # put_nowait never blocks on an unbounded queue, so the sentinel
            # is delivered even while this task is being cancelled.
            queue.put_nowait(producer_done)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        remaining = len(_tasks)
        try:
            # A producer's items precede its sentinel in the FIFO queue, so
            # every item has been yielded once all sentinels are consumed.
            while remaining:
                item = await queue.get()
                if item is producer_done:
                    remaining -= 1
                elif isinstance(item, Exception):
                    raise item
                else:
                    yield item
            await asyncio.gather(*_tasks)
        finally:
            # Stop producers that are still running if we exit early.
            for task in _tasks:
                task.cancel()

    return consumer()


def get_ip() -> str:
    """Return an IP address of this host, as a string.

    Resolution order:
      1. the ``HOST_IP`` environment variable, if set and non-empty;
      2. the local address chosen for a UDP socket "connected" to a public
         IPv4 DNS server;
      3. the same with a public IPv6 DNS server;
      4. ``"0.0.0.0"``, after emitting a warning.
    """
    host_ip = os.environ.get("HOST_IP")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    # connect() on a UDP socket only selects a route/local address, so the
    # targets don't need to be reachable.
    candidates = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, address in candidates:
        try:
            # Socket creation is inside the try so a missing address family
            # falls through to the next candidate; `with` closes the socket.
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(address)
                return s.getsockname()[0]
        except Exception:
            # Best effort: try the next address family.
            continue

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"


def get_distributed_init_method(ip: str, port: int) -> str:
    """Build a ``tcp://`` init-method URL from ``ip`` and ``port``.

    An ``ip`` containing ``:`` is treated as IPv6 and wrapped in brackets.
    IPv4 addresses are left bare, because brackets are not permitted there,
    see https://github.com/python/cpython/issues/103848
    """
    if ":" not in ip:
        return f"tcp://{ip}:{port}"
    return f"tcp://[{ip}]:{port}"


def get_open_port() -> int:
    """Return a TCP port number that the OS reported as free.

    The function binds to port 0 so the OS assigns an ephemeral port, then
    closes the socket and returns that port. Another process may grab the
    port before the caller binds it. IPv4 is tried first; if that raises
    ``OSError``, IPv6 is used instead.
    """

    def _bind_ephemeral(family: int) -> int:
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    try:
        return _bind_ephemeral(socket.AF_INET)
    except OSError:
        return _bind_ephemeral(socket.AF_INET6)


def update_environment_variables(envs: Dict[str, str]):
    """Copy every key/value pair of ``envs`` into ``os.environ``.

    Existing variables are overwritten. Each overwrite first logs a warning
    showing the old and new values.
    """
    for name, new_value in envs.items():
        # os.environ values are always str, so None means "not set".
        previous = os.environ.get(name)
        if previous is not None:
            logger.warning(f"Overwriting environment variable {name} "
                           f"from '{previous}' to '{new_value}'")
        os.environ[name] = new_value


def chunk_list(lst, chunk_size: int) -> list:
    """Split ``lst`` into consecutive chunks of ``chunk_size`` items.

    Returns a list (not a generator) of slices of ``lst``. The last chunk
    holds the remainder and may be shorter. An empty ``lst`` gives ``[]``.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    # A zero step would make range() raise an unrelated error, and a
    # negative one would silently produce no chunks.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer not less than ``a / b``.

    Uses floor division on the negated numerator, which stays exact for
    arbitrarily large ints (no float rounding).
    """
    negated_floor = (-a) // b
    return -negated_floor


@lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]:
    """Return the CUDA version reported by ``nvcc -V`` (cached).

    ``nvcc`` is run from ``$CUDA_HOME/bin``. When ``CUDA_HOME`` is unset,
    ``/usr/local/cuda`` is used if it contains ``bin/nvcc``; otherwise a
    warning is logged and ``None`` is returned. If ``CUDA_HOME`` is set, the
    binary's existence is not checked, so a missing ``nvcc`` there surfaces
    as an error from ``subprocess.check_output``.
    """
    cuda_home = os.environ.get('CUDA_HOME')
    if not cuda_home:
        cuda_home = '/usr/local/cuda'
        if not os.path.isfile(cuda_home + '/bin/nvcc'):
            logger.warning(
                f'Not found nvcc in {cuda_home}. Skip cuda version check!')
            return None
        logger.info(f'CUDA_HOME is not found in the environment. '
                    f'Using {cuda_home} as CUDA_HOME.')

    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
                                          universal_newlines=True)
    # Expects "... release X.Y, ..." in the output: take the token after
    # "release" and drop its trailing comma.
    tokens = nvcc_output.split()
    release_token = tokens[tokens.index("release") + 1]
    return parse(release_token.split(",")[0])


def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill ``tensor`` with fp8-encoded random values from ``[low, high)``.

    Values are sampled uniformly in float16 and passed to
    ``ops.convert_fp8(tensor_tmp, tensor)``. That call is expected to write
    the fp8 encoding into ``tensor``; nothing is returned, so ``tensor`` is
    the output.
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    # Imported lazily so that the compiled custom ops are only needed when
    # this function is actually called.
    from vllm import _custom_ops as ops

    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor_tmp, tensor)


def _resolve_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]],
) -> torch.dtype:
    """Return the torch dtype used to store a randomly initialized KV cache.

    ``cache_dtype`` may be a ``torch.dtype`` (used as is), ``"auto"``
    (follow ``model_dtype``), one of ``"half"``/``"bfloat16"``/``"float"``,
    or ``"fp8"`` (stored as ``torch.uint8``).

    Raises:
        ValueError: for any other ``cache_dtype``, or for ``"auto"`` when
            ``model_dtype`` is neither a string nor a ``torch.dtype``.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    if cache_dtype == "auto":
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        raise ValueError(f"Invalid model dtype: {model_dtype}")
    if cache_dtype in ("half", "bfloat16", "float"):
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if cache_dtype == "fp8":
        return torch.uint8
    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")


def _fill_kv_cache_randomly(
    cache: torch.Tensor,
    cache_dtype: Union[str, torch.dtype],
    scale: float,
    cache_kind: str,
) -> None:
    """Fill ``cache`` in place with random values in ``[-scale, scale]``.

    ``cache_kind`` ("key" or "value") is only used in the error message.

    Raises:
        ValueError: if ``cache_dtype`` cannot be filled randomly.
    """
    if isinstance(cache_dtype, torch.dtype):
        # An explicit floating-point dtype is filled directly. Other explicit
        # dtypes (e.g. integers, which uniform_ cannot fill) fall through to
        # the error below.
        if cache_dtype.is_floating_point:
            cache.uniform_(-scale, scale)
            return
    elif cache_dtype in ("auto", "half", "bfloat16", "float"):
        cache.uniform_(-scale, scale)
        return
    elif cache_dtype == "fp8":
        _generate_random_fp8(cache, -scale, scale)
        return
    raise ValueError(
        f"Does not support {cache_kind} cache of type {cache_dtype}")


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer key and value caches filled with random data.

    Args:
        num_blocks, block_size, num_heads, head_size: Cache dimensions.
        num_layers: Number of key (and value) cache tensors to create.
        cache_dtype: A ``torch.dtype``, or one of ``"auto"``, ``"half"``,
            ``"bfloat16"``, ``"float"``, ``"fp8"``.
        model_dtype: Dtype used when ``cache_dtype`` is ``"auto"``.
        seed: Seed for the torch (and, if available, CUDA) RNG.
        device: Device on which the caches are allocated.

    Returns:
        ``(key_caches, value_caches)``, each a list of ``num_layers`` tensors.
        Key caches have shape
        ``(num_blocks, num_heads, head_size // x, block_size, x)`` with
        ``x = 16 // element_size``. Value caches have shape
        ``(num_blocks, num_heads, head_size, block_size)``.

    Raises:
        ValueError: if ``cache_dtype`` / ``model_dtype`` is unsupported.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = _resolve_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    # x elements of torch_dtype span 16 bytes; the key cache's head dimension
    # is split into chunks of x elements.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)

    def make_caches(shape, cache_kind: str) -> List[torch.Tensor]:
        caches = []
        for _ in range(num_layers):
            cache = torch.empty(size=shape, dtype=torch_dtype, device=device)
            _fill_kv_cache_randomly(cache, cache_dtype, scale, cache_kind)
            caches.append(cache)
        return caches

    # Generate every key cache before any value cache. The order of RNG
    # draws determines the data produced for a given seed.
    key_caches = make_caches(key_cache_shape, "key")
    value_caches = make_caches(value_cache_shape, "value")
    return key_caches, value_caches


@lru_cache(maxsize=128)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` as a warning, skipping repeats of an equal message.

    Deduplication comes from ``lru_cache``: a repeated call hits the cache
    and never reaches the logger. The cache holds 128 entries (the
    ``lru_cache`` default), so a message evicted by many newer distinct
    messages can be logged again.
    """
    logger.warning(msg)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    """Return whether pinned host memory should be used (cached).

    Pinning is disabled under WSL, on Neuron, and for CPU builds. The WSL
    and Neuron cases also log a one-time warning.
    """
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        reason = ("Using 'pin_memory=False' as WSL is detected. "
                  "This may slow down the performance.")
    elif is_neuron():
        reason = "Pin memory is not supported on Neuron."
    elif is_cpu():
        return False
    else:
        return True
    print_warning_once(reason)
    return False


class CudaMemoryProfiler:
    """Context manager recording CUDA memory usage around a block.

    On entry and on exit it records ``current_memory_usage()`` for
    ``device`` (as ``initial_memory`` / ``final_memory``) and stores the
    difference in ``consumed_memory``.
    """

    def __init__(self, device=None):
        self.device = device

    def current_memory_usage(self) -> float:
        """Return the bytes currently allocated on ``self.device``.

        The peak statistics are reset first, so ``max_memory_allocated``
        reflects the current allocation rather than an earlier peak (per the
        ``torch.cuda.reset_peak_memory_stats`` docs). Side effect: the
        device's peak-memory stats are cleared.
        """
        torch.cuda.reset_peak_memory_stats(self.device)
        return torch.cuda.max_memory_allocated(self.device)

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # Hand back the profiler so callers can read its results later.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        measured = self.current_memory_usage()
        self.final_memory = measured
        self.consumed_memory = measured - self.initial_memory
        # Force garbage collection once the measurement has been taken.
        gc.collect()


def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Parse a comma-separated string such as ``"1,2,3"`` into ``(1, 2, 3)``.

    Raises:
        ValueError: if any comma-separated piece is not an integer literal.
    """
    pieces = s.split(",")
    try:
        return tuple(int(piece) for piece in pieces)
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
    """Return a new list: ``x`` right-padded with ``pad`` to ``max_len``.

    ``x`` must not be longer than ``max_len`` (checked with ``assert``).
    """
    missing = max_len - len(x)
    assert missing >= 0
    return x + [pad] * missing


def make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
    """Build a tensor from ragged rows of ints.

    Each inner list is right-padded with ``pad`` up to ``max_len`` (via
    ``pad_to_max_length``), so a non-empty ``x`` yields a
    ``(len(x), max_len)`` tensor of ``dtype`` on ``device``.
    """
    padded_rows = [pad_to_max_length(row, max_len, pad) for row in x]
    result = torch.tensor(padded_rows, dtype=dtype, device=device)
    return result


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build ``data`` as a CPU tensor and copy it to ``target_device``.

    The CPU staging tensor is allocated in pinned memory when ``pin_memory``
    is true. The copy is issued with ``non_blocking=True``; per the PyTorch
    docs it only overlaps with host work when the source is pinned.
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               pin_memory=pin_memory,
                               device="cpu")
    return host_tensor.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Reshape a low-rank ``tensor`` by appending trailing dims of ``size``.

    When ``tensor.ndim < target_dims``, the function returns
    ``tensor.view(-1, size, ..., size)`` with ``target_dims - tensor.ndim``
    trailing ``size`` dims; otherwise it returns ``tensor`` itself. All
    leading dims are flattened into the ``-1`` dim, so the result has
    exactly ``target_dims`` dims only for 1-D inputs.
    """
    missing_dims = target_dims - tensor.ndim
    if missing_dims <= 0:
        return tensor
    return tensor.view(-1, *((size, ) * missing_dims))


def merge_dicts(dict1: Dict[Any, List[Any]],
                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
    """Merge two dicts whose values are lists.

    Nothing is dropped on a key conflict. Every key from either dict maps to
    a new list holding ``dict1``'s items for that key followed by
    ``dict2``'s items. Keys keep first-seen order (``dict1``'s keys first),
    and neither input dict nor its lists are modified.
    """
    merged_dict: Dict[Any, List[Any]] = defaultdict(list)
    for source in (dict1, dict2):
        for key, value in source.items():
            merged_dict[key].extend(value)
    return dict(merged_dict)


def init_cached_hf_modules():
    """Call ``transformers.dynamic_module_utils.init_hf_modules()``.

    ``transformers`` is imported inside the function, so it is only needed
    when this function is actually called.
    """
    from transformers.dynamic_module_utils import (
        init_hf_modules as _init_hf_modules)
    _init_hf_modules()