cuda.py 9.78 KB
Newer Older
1
2
3
4
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

5
import os
6
from functools import lru_cache, wraps
7
8
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
                    Union)
9
10

import pynvml
11
import torch
12
from typing_extensions import ParamSpec
13

14
15
# import custom ops, trigger op registration
import vllm._C  # noqa
16
import vllm.envs as envs
17
18
from vllm.logger import init_logger

19
from .interface import DeviceCapability, Platform, PlatformEnum
20

21
22
23
24
25
if TYPE_CHECKING:
    # Only needed for annotations; skipped at runtime.
    from vllm.config import VllmConfig
else:
    # Runtime placeholder so annotations naming `VllmConfig` still resolve.
    VllmConfig = None

logger = init_logger(__name__)

# Generic parameters used to type the `with_nvml_context` decorator so that
# the wrapped callable keeps its signature and return type.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# A `pynvml` module whose file is a package `__init__.py` is treated as the
# deprecated distribution; point the user at `nvidia-ml-py` instead.
if pynvml.__file__.endswith("__init__.py"):
    logger.warning(
        "You are using a deprecated `pynvml` package. Please install"
        " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
        " When both of them are installed, `pynvml` will take precedence"
        " and cause errors. See https://pypi.org/project/pynvml "
        "for more information.")

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device id.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` is used as an index
    into its comma-separated entries and the selected entry is returned as an
    ``int``. When the variable is not set, ``device_id`` is returned as is.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set to an empty string.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # No remapping in effect: logical and physical ids coincide.
        return device_id

    if not visible_devices:
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")

    entries = visible_devices.split(",")
    return int(entries[device_id])
60

61

62
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so it runs between ``nvmlInit`` and ``nvmlShutdown``.

    ``pynvml.nvmlInit()`` is called before ``fn``, and
    ``pynvml.nvmlShutdown()`` is called afterwards whether ``fn`` returns
    normally or raises.
    """

    @wraps(fn)
    def nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always pair the init above with a shutdown.
            pynvml.nvmlShutdown()
        return result

    return nvml_scoped


75
76
class CudaPlatformBase(Platform):
    """Common base for the CUDA platform implementations in this module.

    The device queries (`get_device_capability`, `get_device_name`,
    `get_device_total_memory`, `is_full_nvlink`) raise `NotImplementedError`
    here; the subclasses in this module override them.
    """
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Not implemented in the base class."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Not implemented in the base class."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented in the base class."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return whether async output processing can be used.

        Returns True unless `enforce_eager` is truthy, in which case a
        warning is logged and False is returned.
        """
        if not enforce_eager:
            return True

        logger.warning(
            "To see benefits of async output processing, enable CUDA "
            "graph. Since, enforce-eager is enabled, async output "
            "processor cannot be used")
        return False

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        """Not implemented in the base class."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Hook for platform warnings; the base class logs nothing."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Fill in platform defaults on `vllm_config` (mutated in place).

        * If `parallel_config.worker_cls` is "auto", it is replaced by a
          concrete worker class path chosen from the multi-step setting, the
          presence of a speculative config, and `envs.VLLM_USE_V1`.
        * If a cache config exists and its `block_size` is None, it is set
          to 16.

        Raises:
            NotImplementedError: when `envs.VLLM_USE_V1` is set together
                with multi-step scheduling or a speculative config while
                the worker class is "auto".
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError
                parallel_config.worker_cls = (
                    "vllm.worker.multi_step_worker.MultiStepWorker")
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError
                parallel_config.worker_cls = (
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker")
                parallel_config.sd_worker_cls = "vllm.worker.worker.Worker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        # Apply the default block size only when the user left it unset.
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

144

145
146
147
148
149
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML (`pynvml`).

    Public methods taking a `device_id` first translate it with
    `device_id_to_physical_device_id` before talking to NVML.
    """

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of `device_id`.

        Returns None if a `RuntimeError` is raised while resolving or
        querying the device. Results are cached by `lru_cache`.
        """
        try:
            nvml_index = device_id_to_physical_device_id(device_id)
            nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
            cc_major, cc_minor = pynvml.nvmlDeviceGetCudaComputeCapability(
                nvml_handle)
            return DeviceCapability(major=cc_major, minor=cc_minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent check inside an NVML context.

        Returns False if the parent implementation raises `RuntimeError`.
        """
        try:
            supported = super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False
        return supported

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name for `device_id`."""
        return cls._get_physical_device_name(
            device_id_to_physical_device_id(device_id))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the `total` field of NVML's memory info as an int."""
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            device_id_to_physical_device_id(device_id))
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Check every unordered pair once: each handle against those after it.
        for position, handle in enumerate(handles):
            for peer_handle in handles[position + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of the device at physical index `device_id`.

        This method does not initialize NVML itself; its callers in this
        class are wrapped with `with_nvml_context`.
        """
        return pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when different device models are present in the system.

        The warning is only emitted when more than one distinct device name
        is found and `CUDA_DEVICE_ORDER` is not set to `PCI_BUS_ID`.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return

        device_names = [
            cls._get_physical_device_name(index)
            for index in range(device_count)
        ]
        has_mixed_devices = len(set(device_names)) > 1
        pci_ordered = os.environ.get("CUDA_DEVICE_ORDER") == "PCI_BUS_ID"
        if has_mixed_devices and not pci_ordered:
            logger.warning(
                "Detected different devices in the system: \n%s\nPlease"
                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.",
                "\n".join(device_names),
            )


class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform backed by `torch.cuda` queries instead of NVML."""

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by `torch.cuda`."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by `torch.cuda`."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return `total_memory` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Always report that the devices are not fully NVLink-connected.

        No detection is attempted; a warning is logged and False is
        returned regardless of `physical_device_ids`.
        """
        # `logger.exception` is only meaningful inside an `except` block;
        # called here it would append a bogus "NoneType: None" traceback.
        # There is no active exception, so log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # The probe succeeded: record it, then undo the probe's init.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit platform warnings at import time, except when `pynvml` is a sphinx
# autodoc mock module. If sphinx is not installed, always emit them.
try:
    from sphinx.ext.autodoc.mock import _MockModule

    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    CudaPlatform.log_warnings()