cuda.py 5.54 KB
Newer Older
1
2
3
4
"""Code in this file can safely assume the CUDA platform (e.g. it may import
pynvml). However, it must not initialize a CUDA context.
"""

5
import os
6
from functools import lru_cache, wraps
7
from typing import Callable, List, Tuple, TypeVar
8
9

import pynvml
10
import torch
11
from typing_extensions import ParamSpec
12

13
14
# import custom ops, trigger op registration
import vllm._C  # noqa
15
16
from vllm.logger import init_logger

17
from .interface import DeviceCapability, Platform, PlatformEnum
18

19
20
logger = init_logger(__name__)

# Type variables used to preserve the signature of NVML-wrapped functions.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# A `pynvml` whose module file is a package `__init__.py` is treated as the
# deprecated `pynvml` distribution; the message below explains why it must be
# replaced with `nvidia-ml-py`.
if pynvml.__file__.endswith("__init__.py"):
    logger.warning(
        "You are using a deprecated `pynvml` package. Please install"
        " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
        " When both of them are installed, `pynvml` will take precedence"
        " and cause errors. See https://pypi.org/project/pynvml "
        "for more information.")

# PyTorch 2.5 uses the cuDNN SDPA backend by default, which causes crashes on
# some models; see https://github.com/huggingface/diffusers/issues/9704.
torch.backends.cuda.enable_cudnn_sdp(False)

36
37
38
39
40
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`: all of the
# related functions operate on real physical device ids.
# The main benefit of using NVML is that it does not initialize CUDA.

41

42
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so each call runs inside an NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` and
    ``pynvml.nvmlShutdown()`` afterwards, including when ``fn`` raises.
    The wrapper keeps ``fn``'s signature and metadata (via ``functools.wraps``).
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always release the NVML session, even on error.
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped


55
56
57
58
59
60
61
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the ``(major, minor)`` CUDA compute capability of a GPU.

    ``device_id`` is a physical NVML index, not a ``CUDA_VISIBLE_DEVICES``
    relative one. Results are memoized per ``device_id``.
    """
    return pynvml.nvmlDeviceGetCudaComputeCapability(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))


62
63
64
65
66
67
68
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_name(device_id: int = 0) -> str:
    """Return the device name NVML reports for a GPU.

    ``device_id`` is a physical NVML index, not a ``CUDA_VISIBLE_DEVICES``
    relative one. Results are memoized per ``device_id``.

    Returns:
        The name as a ``str``. If the NVML binding returns ``bytes``, the
        value is decoded first.
    """
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    name = pynvml.nvmlDeviceGetName(handle)
    # Older NVML Python bindings (such as the deprecated `pynvml` package this
    # module warns about at import time) may return bytes here. Normalize to
    # str so the annotated return type holds and callers that do string
    # operations (e.g. "\n".join(...)) work.
    if isinstance(name, bytes):
        name = name.decode("utf-8", errors="replace")
    return name


69
70
71
72
73
74
75
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_total_memory(device_id: int = 0) -> int:
    """Return a GPU's total memory as reported by NVML.

    The value is the ``total`` field of ``nvmlDeviceGetMemoryInfo``, coerced
    to ``int``. ``device_id`` is a physical NVML index. Results are memoized
    per ``device_id``.
    """
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))
    return int(mem_info.total)


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
@with_nvml_context
def warn_if_different_devices():
    """Warn when the machine has GPUs with differing names.

    The warning is only emitted when ``CUDA_DEVICE_ORDER`` is not set to
    ``PCI_BUS_ID``. Nothing is checked when there are fewer than two devices.
    """
    device_count: int = pynvml.nvmlDeviceGetCount()
    if device_count <= 1:
        return

    names = [get_physical_device_name(idx) for idx in range(device_count)]
    mixed_models = len(set(names)) > 1
    # Presumably CUDA's default device numbering can differ from NVML's
    # physical indices on mixed systems; PCI_BUS_ID ordering aligns them.
    pci_bus_order = os.environ.get("CUDA_DEVICE_ORDER") == "PCI_BUS_ID"
    if mixed_models and not pci_bus_order:
        logger.warning(
            "Detected different devices in the system: \n%s\nPlease"
            " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
            "avoid unexpected behavior.", "\n".join(names))


# Run the mixed-device check at import time, except when Sphinx autodoc has
# replaced `pynvml` with a mock module: NVML calls are meaningless there.
# Only the import goes inside `try`. A ModuleNotFoundError raised by the check
# itself then propagates instead of triggering a second call.
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ModuleNotFoundError:
    # Sphinx is not installed, so `pynvml` cannot be an autodoc mock.
    warn_if_different_devices()
else:
    if not isinstance(pynvml, _MockModule):
        warn_if_different_devices()


98
99
100
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into a physical one.

    If ``CUDA_VISIBLE_DEVICES`` is unset, ``device_id`` is returned unchanged.
    Otherwise the variable is split on commas, and the entry at position
    ``device_id`` is parsed with ``int``.

    Raises:
        RuntimeError: ``CUDA_VISIBLE_DEVICES`` is set to the empty string.
        IndexError: ``device_id`` is out of range for the visible list.
        ValueError: the selected entry is not an integer.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return device_id

    entries = visible.split(",")
    if entries == [""]:
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")
    return int(entries[device_id])
114
115


116
117
class CudaPlatform(Platform):
    """CUDA platform whose device queries go through NVML.

    Logical device ids (relative to ``CUDA_VISIBLE_DEVICES``) are mapped to
    physical NVML indices before querying.
    """

    _enum = PlatformEnum.CUDA
    device_type: str = "cuda"

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of logical CUDA device ``device_id``."""
        major, minor = get_physical_device_capability(
            device_id_to_physical_device_id(device_id))
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name of logical CUDA device ``device_id``."""
        return get_physical_device_name(
            device_id_to_physical_device_id(device_id))

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the NVML-reported total memory of logical device ``device_id``."""
        return get_physical_device_total_memory(
            device_id_to_physical_device_id(device_id))

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Check whether every pair of the given GPUs is linked by NVLink.

        ``physical_device_ids`` are physical NVML indices. Each unordered
        pair is queried once through NVML's NVLink P2P capability. The method
        returns ``False`` on the first pair that is not OK, or on any NVML
        error, which is logged. With fewer than two GPUs it returns ``True``,
        since there are no pairs to check.
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(idx)
            for idx in physical_device_ids
        ]
        count = len(handles)
        for first in range(count):
            for second in range(first + 1, count):
                try:
                    status = pynvml.nvmlDeviceGetP2PStatus(
                        handles[first], handles[second],
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped.")
                    return False
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True