"""Code inside this file can safely assume the CUDA platform, e.g. importing
pynvml. However, it should not initialize a CUDA context.
"""

5
import os
6
from functools import lru_cache, wraps
7
from typing import TYPE_CHECKING, Callable, List, TypeVar
8
9

import pynvml
10
import torch
11
from typing_extensions import ParamSpec
12

13
14
# import custom ops, trigger op registration
import vllm._C  # noqa
15
16
from vllm.logger import init_logger

17
from .interface import DeviceCapability, Platform, PlatformEnum
# `VllmConfig` is only imported for static type checkers; at runtime the name
# is bound to ``None`` so annotations that mention it still evaluate
# (presumably to avoid an import cycle with `vllm.config` — TODO confirm).
if not TYPE_CHECKING:
    VllmConfig = None
else:
    from vllm.config import VllmConfig

# Type variables that let decorators such as `with_nvml_context` declare that
# they return a callable with exactly the wrapped function's signature.
_P = ParamSpec("_P")
_R = TypeVar("_R")

logger = init_logger(__name__)

# A package-style `pynvml` (one whose module file is an `__init__.py`) is taken
# to mean the deprecated `pynvml` distribution is installed rather than
# `nvidia-ml-py`; warn the user since the two conflict.
if pynvml.__file__.endswith("__init__.py"):
    logger.warning(
        "You are using a deprecated `pynvml` package. Please install"
        " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
        " When both of them are installed, `pynvml` will take precedence"
        " and cause errors. See https://pypi.org/project/pynvml "
        "for more information.")

37
38
39
40
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57

def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into a physical device index.

    If ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` selects one of its
    comma-separated entries and that entry is parsed with ``int``. If the
    variable is unset, ``device_id`` is returned as-is.

    Raises:
        RuntimeError: ``CUDA_VISIBLE_DEVICES`` is set to the empty string.
        IndexError: ``device_id`` is out of range for the listed entries.
        ValueError: the selected entry is not an integer.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # No remapping in effect: logical and physical ids coincide.
        return device_id
    if not visible_devices:
        # An empty value hides every GPU; there is nothing to map to.
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")
    return int(visible_devices.split(",")[device_id])


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that it runs inside an NVML session.

    The returned wrapper calls ``pynvml.nvmlInit()`` before invoking ``fn``
    and always calls ``pynvml.nvmlShutdown()`` afterwards, even when ``fn``
    raises. The wrapper keeps ``fn``'s signature and metadata.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            # Runs on both the success and the exception path.
            pynvml.nvmlShutdown()

    return wrapper


class CudaPlatformBase(Platform):
    """Common base for the CUDA platform implementations in this file.

    Device queries raise ``NotImplementedError`` here and are provided by
    subclasses; worker-class selection in ``check_and_update_config`` is
    shared by all of them.
    """

    # Enum value identifying this platform.
    _enum = PlatformEnum.CUDA
    # Device type string for this platform.
    device_type: str = "cuda"
    # Dispatch key string for this platform (presumably the torch dispatch
    # key for CUDA kernels — TODO confirm).
    dispatch_key: str = "CUDA"

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        """Return whether ``device_ids`` are fully NVLink-connected.

        Subclass hook. Note that subclasses in this file name the parameter
        ``physical_device_ids``.
        """
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Emit platform-specific warnings; a no-op by default."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Resolve ``parallel_config.worker_cls`` when it is ``"auto"``.

        Multi-step scheduling selects ``MultiStepWorker``. Speculative
        decoding selects the spec-decode worker factory and sets
        ``sd_worker_cls`` to the plain ``Worker``. Otherwise ``Worker`` is
        used. An explicitly configured ``worker_cls`` is left untouched.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls != "auto":
            return
        if scheduler_config.is_multi_step:
            parallel_config.worker_cls = \
                "vllm.worker.multi_step_worker.MultiStepWorker"
        elif vllm_config.speculative_config:
            parallel_config.worker_cls = \
                "vllm.spec_decode.spec_decode_worker.create_spec_worker"
            parallel_config.sd_worker_cls = "vllm.worker.worker.Worker"
        else:
            parallel_config.worker_cls = "vllm.worker.worker.Worker"


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`; all the related
# functions work on real physical device ids. The major benefit of using NVML
# is that it does not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through NVML.

    Queries taking a logical ``device_id`` translate it with
    ``device_id_to_physical_device_id`` before calling NVML. Every NVML-backed
    public query is wrapped in ``with_nvml_context``.
    """

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of logical device ``device_id``.

        Memoized per ``(cls, device_id)``; a cache hit skips NVML entirely.
        NOTE(review): the key is the logical id, so a later change to
        ``CUDA_VISIBLE_DEVICES`` in this process would not be reflected.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name of logical device ``device_id``.

        Memoized per ``(cls, device_id)`` like ``get_device_capability``.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return total memory of logical device ``device_id`` as an int.

        The value is NVML's ``nvmlDeviceGetMemoryInfo(...).total`` (bytes,
        per the NVML docs). Memoized per ``(cls, device_id)``.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Query whether the GPUs are fully connected by NVLink (1 hop).

        The ids are passed to NVML unchanged, i.e. they are physical indices.
        Every unordered pair must report ``NVML_P2P_STATUS_OK`` for the
        NVLink capability. An ``NVMLError`` from a pair query is logged and
        treated as "no NVLink"; errors from obtaining handles propagate.
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for i, handle in enumerate(handles):
            # Visit each unordered pair once: only peers after `handle`.
            for peer_handle in handles[i + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of physical device ``device_id``.

        Not wrapped in ``with_nvml_context`` itself; the callers in this
        class invoke it from inside an NVML session.
        """
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn about mixed GPU models when device order is not by PCI bus.

        Fires only when more than one GPU is present, their NVML names
        differ, and ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return
        device_names = [
            cls._get_physical_device_name(i) for i in range(device_count)
        ]
        if (len(set(device_names)) > 1
                and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
            logger.warning(
                "Detected different devices in the system: \n%s\nPlease"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                "\n".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """Fallback CUDA platform that queries devices through ``torch.cuda``.

    Selected when NVML is not available. NOTE(review): unlike the NVML path,
    these ``torch.cuda`` queries likely initialize a CUDA context — confirm.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from ``torch.cuda`` device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Always report no NVLink: it cannot be detected without NVML."""
        # Not inside an exception handler, so `logger.exception` would append
        # a meaningless "NoneType: None" traceback; a plain warning suffices.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Pick the NVML-backed platform when NVML can be initialized in this process,
# and the torch.cuda-based fallback otherwise.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # The probe only checks that initialization succeeds; release it again.
    pynvml.nvmlShutdown()

CudaPlatform = (NvmlCudaPlatform
                if nvml_available else NonNvmlCudaPlatform)

# Emit platform warnings at import time, except when `pynvml` is a Sphinx
# autodoc mock (i.e. during a documentation build).
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ModuleNotFoundError:
    # Sphinx is not installed, so this cannot be an autodoc build.
    CudaPlatform.log_warnings()
else:
    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()