utils.py 19.7 KB
Newer Older
1
import asyncio
2
import datetime
Woosuk Kwon's avatar
Woosuk Kwon committed
3
import enum
4
import gc
5
import glob
6
import os
7
import socket
8
import subprocess
9
10
import tempfile
import threading
Zhuohan Li's avatar
Zhuohan Li committed
11
import uuid
12
import warnings
13
from collections import defaultdict
14
from functools import lru_cache, partial
15
from platform import uname
16
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
17
18
                    Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
                    Union)
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import psutil
Zhuohan Li's avatar
Zhuohan Li committed
21
import torch
22
from packaging.version import Version, parse
23

24
from vllm.logger import enable_trace_function_call, init_logger
25

26
T = TypeVar("T")

logger = init_logger(__name__)

# Lookup from the dtype names accepted as strings (for example by
# `create_kv_caches_with_random`) to torch dtypes. "fp8" is backed by uint8
# storage.
STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    bfloat16=torch.bfloat16,
    float=torch.float,
    fp8=torch.uint8,
)
Zhuohan Li's avatar
Zhuohan Li committed
35

Woosuk Kwon's avatar
Woosuk Kwon committed
36
37
38
39
40
41
42
43
44
45
46

class Device(enum.Enum):
    """Device kinds: GPU or CPU."""

    GPU = 1
    CPU = 2


class Counter:
    """Integer counter advanced with ``next()``.

    Each ``next()`` returns the current value and then increments it.
    """

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        # Post-increment: hand out the current value, then advance.
        current, self.counter = self.counter, self.counter + 1
        return current

    def reset(self) -> None:
        # Resets to 0, not to the ``start`` passed to __init__.
        self.counter = 0
Zhuohan Li's avatar
Zhuohan Li committed
54

55

56
class LRUCache(Generic[T]):
    """Key/value store that evicts the least recently used entry once it
    holds more than ``capacity`` items.

    Recency order is kept in an ``OrderedDict`` (oldest first). Entries that
    leave the cache through eviction, ``pop``/``del`` or ``clear`` are
    reported to ``_on_remove``, which subclasses may override.
    """

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> Optional[T]:
        # Unlike a dict, a missing key returns None instead of raising.
        return self.get(key)

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        """Mark ``key`` as most recently used; raises KeyError if absent."""
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Return the value for ``key`` and refresh its recency, or
        ``default_value`` when the key is absent (recency untouched)."""
        if key not in self.cache:
            return default_value
        found: Optional[T] = self.cache[key]
        self.cache.move_to_end(key)
        return found

    def put(self, key: Hashable, value: T) -> None:
        """Insert or overwrite ``key`` as the newest entry, then evict the
        oldest entries while over capacity."""
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def _on_remove(self, key: Hashable, value: Optional[T]):
        """Hook invoked for each removed entry; does nothing by default."""

    def remove_oldest(self):
        """Evict the least recently used entry, if any."""
        if self.cache:
            oldest_key, oldest_value = self.cache.popitem(last=False)
            self._on_remove(oldest_key, oldest_value)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        """Remove ``key`` and return its value (``default_value`` if absent).

        ``_on_remove`` fires only if the key was actually cached.
        """
        was_cached = key in self.cache
        popped: Optional[T] = self.cache.pop(key, default_value)
        if was_cached:
            self._on_remove(key, popped)
        return popped

    def clear(self):
        """Evict every entry oldest-first so ``_on_remove`` sees each one."""
        while self.cache:
            self.remove_oldest()
        self.cache.clear()


123
124
125
126
def is_hip() -> bool:
    """Return True when ``torch.version.hip`` is set (a ROCm/HIP build)."""
    hip_version = torch.version.hip
    return hip_version is not None


127
128
@lru_cache(maxsize=None)
def is_cpu() -> bool:
    """Return True if the installed ``vllm`` distribution's version string
    contains "cpu"; False when no ``vllm`` distribution is installed."""
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed_version = version("vllm")
    except PackageNotFoundError:
        return False
    return "cpu" in installed_version
134
135


136
@lru_cache(maxsize=None)
def is_neuron() -> bool:
    """Return True if the ``transformers_neuronx`` package can be imported."""
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return False
    return True


145
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # NOTE: imported lazily because the Neuron-X backend does not ship the
    # `cuda_utils` module.
    from vllm._C import cuda_utils

    limit = cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)
    # A limit of 0 would make MAX_SEQ_LEN negative and break
    # test_attention.py, so reject it here.
    assert limit > 0, "max_shared_mem can not be zero"
    return int(limit)


160
def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    memory_stats = psutil.virtual_memory()
    return memory_stats.total
Zhuohan Li's avatar
Zhuohan Li committed
163
164
165
166


def random_uuid() -> str:
    """Return a new random UUID4 as a 32-character lowercase hex string."""
    # `UUID.hex` is already a str; no conversion needed.
    return uuid.uuid4().hex
167

168

169
170
171
172
173
174
175
176
177
178
179
@lru_cache(maxsize=None)
def get_vllm_instance_id():
    """Return the id shared by all processes of one vLLM instance.

    Uses the VLLM_INSTANCE_ID environment variable when it is set, otherwise
    a freshly generated ``vllm-instance-<uuid>``. The result is cached, so
    every later call in this process returns the same id.
    """
    # The fallback is built unconditionally, as the dict-style default
    # argument below is evaluated before the lookup.
    fallback_id = f"vllm-instance-{random_uuid()}"
    return os.environ.get("VLLM_INSTANCE_ID", fallback_id)


180
@lru_cache(maxsize=None)
def in_wsl() -> bool:
    """Detect Windows Subsystem for Linux from the platform uname fields."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return any("microsoft" in field.lower() for field in uname())
184
185


186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Take a blocking function, and run it in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.

    Args:
        func: the blocking callable to wrap.

    Returns:
        A callable with the same arguments as ``func`` that returns an
        ``asyncio.Future`` resolving to ``func``'s result. It carries
        ``func``'s name and docstring (via ``functools.wraps``).
    """
    from functools import wraps

    @wraps(func)
    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        # run_in_executor only forwards positional args, so bind kwargs here.
        p_func = partial(func, *args, **kwargs)
        # executor=None -> the loop's default executor.
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item; items from one iterator keep their order.

    If an iterator raises an ``Exception``, it is re-raised to the consumer.
    Whenever the consumer stops early (error, ``break``/``aclose``,
    cancellation), the producer tasks that are still running are cancelled.

    Must be called while an event loop is running: one producer task per
    iterator is scheduled immediately.
    """
    # Each producer enqueues this marker as its very last entry. Counting
    # markers lets the consumer stop exactly when every producer is done.
    # A shared "finished" flag would not work: it cannot wake a consumer
    # that is already blocked in queue.get().
    done_marker = object()
    queue: asyncio.Queue[Union[Tuple[int, T], Exception,
                               object]] = asyncio.Queue()

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        await queue.put(done_marker)

    tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        remaining = len(tasks)
        try:
            while remaining:
                item = await queue.get()
                if item is done_marker:
                    remaining -= 1
                    continue
                if isinstance(item, Exception):
                    raise item
                yield item
            await asyncio.gather(*tasks)
        finally:
            # No-op for finished tasks; stops stragglers on early exit.
            for task in tasks:
                if not task.done():
                    task.cancel()

    return consumer()


238
def get_ip() -> str:
    """Return this host's IP address as a string.

    Resolution order:
      1. the ``HOST_IP`` environment variable, if set and non-empty;
      2. the local address of a UDP socket connected towards a public IPv4
         address, then towards a public IPv6 address;
      3. "0.0.0.0" (with a warning) if all of the above fail.
    """
    host_ip = os.environ.get("HOST_IP")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface.
    # connect() on a UDP socket only selects a route and source address, so
    # the target doesn't need to be reachable.
    probes = (
        # try ipv4
        (socket.AF_INET, ("8.8.8.8", 80)),
        # try ipv6: Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, target in probes:
        try:
            # `with` closes the probe socket on every path, including return.
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(target)
                return s.getsockname()[0]
        except Exception:
            # Best effort: fall through to the next probe / the default.
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"
268
269


270
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` init-method URL for ``ip`` and ``port``."""
    # Any address containing ':' is treated as IPv6 and bracketed. Brackets
    # are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    if ":" in ip:
        return f"tcp://[{ip}]:{port}"
    return f"tcp://{ip}:{port}"
274
275


276
def get_open_port() -> int:
    """Return a currently free TCP port, trying IPv4 first, then IPv6."""

    def _bind_free_port(family: socket.AddressFamily) -> int:
        # Binding to port 0 lets the OS pick an unused port.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    try:
        return _bind_free_port(socket.AF_INET)
    except OSError:
        # IPv4 bind failed; fall back to IPv6.
        return _bind_free_port(socket.AF_INET6)
287
288


289
290
def update_environment_variables(envs: Dict[str, str]):
    """Copy each ``envs`` entry into ``os.environ``.

    Logs a warning for every variable that already exists with a different
    value before overwriting it.
    """
    for name, new_value in envs.items():
        # os.environ never stores None, so None here means "not set".
        old_value = os.environ.get(name)
        if old_value is not None and old_value != new_value:
            logger.warning(
                "Overwriting environment variable %s "
                "from '%s' to '%s'", name, old_value, new_value)
        os.environ[name] = new_value
296
297


298
299
300
301
302
303
304
305
306
307
def chunk_list(lst, chunk_size):
    """Split ``lst`` into consecutive slices of ``chunk_size`` elements.

    Returns a list (not a generator); the last slice may be shorter.
    """
    starts = range(0, len(lst), chunk_size)
    return [lst[begin:begin + chunk_size] for begin in starts]


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the smallest integer >= a / b."""
    # Negate, floor-divide, negate back: -floor(-a / b) == ceil(a / b).
    return -(-a // b)


308
@lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]:
    """Return the CUDA version reported by ``nvcc -V``.

    nvcc is looked up under ``$CUDA_HOME/bin``. When CUDA_HOME is unset,
    ``/usr/local/cuda`` is used instead, and None is returned (with a
    warning) if it has no nvcc. An explicitly set CUDA_HOME is not checked
    up front, so a missing nvcc there surfaces as an exception from
    ``subprocess.check_output``.
    """
    cuda_home = os.environ.get('CUDA_HOME')
    if not cuda_home:
        cuda_home = '/usr/local/cuda'
        if not os.path.isfile(cuda_home + '/bin/nvcc'):
            logger.warning('Not found nvcc in %s. Skip cuda version check!',
                           cuda_home)
            return None
        logger.info(
            'CUDA_HOME is not found in the environment. '
            'Using %s as CUDA_HOME.', cuda_home)

    nvcc_path = cuda_home + "/bin/nvcc"
    tokens = subprocess.check_output([nvcc_path, "-V"],
                                     universal_newlines=True).split()
    # The version is the token right after "release" (typically "12.1,");
    # strip the trailing comma before parsing.
    version_text = tokens[tokens.index("release") + 1].split(",")[0]
    return parse(version_text)


329
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill an fp8 cache tensor with random values while avoiding NaN/Inf.

    Samples U(low, high) into a temporary float16 tensor shaped like
    ``tensor`` and passes both to ``ops.convert_fp8``.
    NOTE(review): presumes ``convert_fp8(src, dst)`` writes the fp8 encoding
    of ``src`` into ``dst`` (here ``tensor``) -- confirm against
    vllm._custom_ops.

    Args:
        tensor: destination tensor (uint8-backed when created for an "fp8"
            cache by ``create_kv_caches_with_random``).
        low: lower bound of the uniform distribution.
        high: upper bound of the uniform distribution.
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor_tmp, tensor)
    del tensor_tmp


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Allocate per-layer key and value caches filled with random values.

    Args:
        num_blocks, block_size, num_layers, num_heads, head_size: geometry.
        cache_dtype: "auto" (use ``model_dtype``), one of "half",
            "bfloat16", "float", "fp8", or a torch.dtype. A torch.dtype
            cache must be torch.half, torch.bfloat16 or torch.float to be
            filled.
        model_dtype: name or torch.dtype used when cache_dtype is "auto".
        seed: seed applied to torch (and to CUDA when available).
        device: device the caches are allocated on.

    Returns:
        (key_caches, value_caches), one tensor per layer. With
        x = 16 // element_size, key caches have shape
        (num_blocks, num_heads, head_size // x, block_size, x) and value
        caches have shape (num_blocks, num_heads, head_size, block_size).

    Raises:
        ValueError: for an invalid or unsupported cache/model dtype.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    scale = head_size**-0.5
    # torch dtypes that are filled the same way as their string aliases.
    float_torch_dtypes = (torch.half, torch.bfloat16, torch.float)

    def _fill_random(cache: torch.Tensor, kind: str) -> None:
        # A torch.dtype never equals a string alias, so it is classified
        # separately; otherwise every dtype-typed cache would be rejected.
        if isinstance(cache_dtype, torch.dtype):
            is_float = cache_dtype in float_torch_dtypes
        else:
            is_float = cache_dtype in ["auto", "half", "bfloat16", "float"]
        if is_float:
            cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support {kind} cache of type {cache_dtype}")

    # x elements of torch_dtype span 16 bytes; the key cache's last
    # dimension is grouped into chunks of that size.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        _fill_random(key_cache, "key")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        _fill_random(value_cache, "value")
        value_caches.append(value_cache)
    return key_caches, value_caches
415
416


417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
@lru_cache(maxsize=128)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` as a warning unless it is still in the call cache.

    The LRU cache (128 entries, same as a bare ``@lru_cache``) suppresses
    repeats; a message evicted from the cache can be logged again.
    """
    logger.warning(msg)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    """Return whether pinned host memory may be used.

    Returns False under WSL and on Neuron (logging a warning), and when
    ``is_cpu()`` is true; True otherwise. The answer is cached.
    """
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    if is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    return not is_cpu()


class CudaMemoryProfiler:
    """Context manager measuring CUDA memory allocated across a ``with``.

    After the block exits, ``initial_memory``, ``final_memory`` and
    ``consumed_memory`` (final - initial, in bytes) are set on the instance.
    """

    def __init__(self, device=None):
        self.device = device

    def current_memory_usage(self) -> float:
        # Resetting the peak first makes max_memory_allocated() report the
        # bytes allocated at this moment.
        torch.cuda.reset_peak_memory_stats(self.device)
        return torch.cuda.max_memory_allocated(self.device)

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # Returning self lets callers read the results via `as`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory
        # Force garbage collection (after measuring).
        gc.collect()
461
462


463
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Convert a comma-separated string such as "1,2,3" to (1, 2, 3).

    Raises:
        ValueError: if any piece is not an integer literal.
    """
    pieces = s.split(",")
    try:
        return tuple(int(piece) for piece in pieces)
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
    """Return a new list: ``x`` right-padded with ``pad`` to ``max_len``."""
    assert len(x) <= max_len
    padding = [pad] * (max_len - len(x))
    return x + padding


def make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
    """Make a padded 2D tensor from a list of integer lists.

    Each inner list is right-padded with ``pad`` up to ``max_len``, so the
    result has shape (len(x), max_len).
    """
    rows = []
    for row in x:
        rows.append(pad_to_max_length(row, max_len, pad))
    return torch.tensor(rows, dtype=dtype, device=device)


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Build a CPU tensor from ``data`` and copy it to ``target_device``.

    The copy is issued with ``non_blocking=True``. Per PyTorch semantics it
    only overlaps with host work when the CPU tensor is pinned, so pass
    ``pin_memory=True`` for a truly asynchronous transfer.
    """
    host_tensor = torch.tensor(data,
                               dtype=dtype,
                               pin_memory=pin_memory,
                               device="cpu")
    return host_tensor.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Expand the tensor to the target_dims.

    Tensors that already have >= ``target_dims`` dims are returned as-is.
    Otherwise the result is ``tensor.view(-1, size, ..., size)`` with
    ``target_dims - tensor.ndim`` trailing ``size`` dims, i.e. it has
    ``1 + target_dims - tensor.ndim`` dims in total.
    NOTE(review): that equals ``target_dims`` only for 1-D input; presumably
    callers pass 1-D tensors -- confirm.
    """
    missing = target_dims - tensor.ndim
    if missing <= 0:
        return tensor
    trailing = [size] * missing
    return tensor.view(-1, *trailing)
512
513


514
515
def merge_dicts(dict1: Dict[Any, List[Any]],
                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
    """Merge 2 dicts that map keys to lists of items.

    For a key present in both dicts, the lists are concatenated with
    dict1's items first. Keys keep first-seen order (dict1, then new dict2
    keys). Returns a new plain dict with fresh lists; the inputs are not
    modified.
    """
    merged: Dict[Any, List[Any]] = {}
    for source in (dict1, dict2):
        for key, items in source.items():
            merged.setdefault(key, []).extend(items)
    return merged
529
530
531
532
533
534
535
536


def init_cached_hf_modules():
    """Lazily initialize the Hugging Face modules.

    ``transformers`` is imported inside the function, so its import cost is
    paid only when this is called; the work is delegated to transformers'
    ``init_hf_modules``.
    """
    from transformers import dynamic_module_utils

    dynamic_module_utils.init_hf_modules()
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560


def nccl_integrity_check(filepath):
    """Verify that the NCCL library at ``filepath`` loads and return its
    version.

    When the library is corrupted, loading it in-process can crash the
    interpreter in a way Python cannot catch. So `ldd` is run on it first,
    in a subprocess, and its exit code decides whether to proceed.

    Args:
        filepath: path to the NCCL (or API-compatible) shared library.

    Returns:
        The integer version reported by ``ncclGetVersion``.

    Raises:
        RuntimeError: if `ldd` fails or cannot be run on ``filepath``, or if
            ``ncclGetVersion`` returns a non-zero status.
    """
    # The path is passed as one argv entry with no shell, so spaces or shell
    # metacharacters in it can't break or hijack the command. Both stdout
    # and stderr are discarded.
    try:
        completed = subprocess.run(["ldd", filepath],
                                   stdout=subprocess.DEVNULL,
                                   stderr=subprocess.DEVNULL,
                                   check=False)
        ldd_ok = completed.returncode == 0
    except OSError:
        # `ldd` itself is missing or not executable.
        ldd_ok = False
    if not ldd_ok:
        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
    import ctypes

    nccl = ctypes.CDLL(filepath)
    version = ctypes.c_int()
    nccl.ncclGetVersion.restype = ctypes.c_int
    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
    result = nccl.ncclGetVersion(ctypes.byref(version))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if result != 0:
        raise RuntimeError(
            f"ncclGetVersion failed with status {result} for {filepath}.")
    return version.value


561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.

    Raises:
        ValueError: if neither ldconfig nor LD_LIBRARY_PATH yields a match.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    ldconfig_listing = subprocess.check_output(["/sbin/ldconfig",
                                                "-p"]).decode()
    # each entry looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    candidates = []
    for entry in ldconfig_listing.splitlines():
        if lib_name in entry:
            candidates.append(entry.split()[-1])

    # Fall back to the user-defined `LD_LIBRARY_PATH` directories.
    search_path = os.getenv("LD_LIBRARY_PATH")
    if not candidates and search_path:
        for directory in search_path.split(":"):
            full_path = os.path.join(directory, lib_name)
            if os.path.exists(full_path):
                candidates.append(full_path)

    if not candidates:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return candidates[0]


589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
def find_nccl_library():
    """Return the path of the NCCL (CUDA) or RCCL (ROCm) library to load.

    Precedence: the VLLM_NCCL_SO_PATH environment variable; then, on CUDA,
    a vllm-managed copy under ~/.config/vllm/nccl/cu<major>/ or the system
    libnccl.so.2; on ROCm, the system librccl.so.1.

    Raises:
        ValueError: if torch was built for neither CUDA nor ROCm.
    """
    so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

    # Look for a vllm-managed nccl matching torch's CUDA major version.
    vllm_nccl_path = None
    if torch.version.cuda is not None:
        cuda_major = torch.version.cuda.split(".")[0]
        pattern = os.path.expanduser(
            f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
        matches = glob.glob(pattern)
        if matches:
            vllm_nccl_path = matches[0]

    if so_file:
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
            so_file)
        return so_file

    if torch.version.cuda is not None:
        so_file = vllm_nccl_path or find_library("libnccl.so.2")
    elif torch.version.hip is not None:
        so_file = find_library("librccl.so.1")
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info("Found nccl from library %s", so_file)
    return so_file
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630


def enable_trace_function_call_for_thread() -> None:
    """Set up function tracing for the current thread,
    if enabled via the VLLM_TRACE_FUNCTION environment variable.

    The trace log is written to <tempdir>/vllm/<instance id>/ under a name
    built from the process id, thread id and current timestamp (spaces
    replaced by underscores).
    """
    if not int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
        return

    filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                f"_thread_{threading.get_ident()}_"
                f"at_{datetime.datetime.now()}.log").replace(" ", "_")
    log_dir = os.path.join(tempfile.gettempdir(), "vllm",
                           get_vllm_instance_id())
    os.makedirs(log_dir, exist_ok=True)
    enable_trace_function_call(os.path.join(log_dir, filename))