"""Code inside this file can safely assume the CUDA platform, e.g. it may
import pynvml. However, it should not initialize a CUDA context.
"""

5
import os
6
from functools import lru_cache, wraps
7
from typing import Callable, List, Tuple, TypeVar
8
9

import pynvml
10
from typing_extensions import ParamSpec
11

12
13
from vllm.logger import init_logger

14
from .interface import DeviceCapability, Platform, PlatformEnum
15

16
17
# Type parameters for ``with_nvml_context`` so that a decorated function
# keeps its original parameter and return types for type checkers.
_P = ParamSpec("_P")
_R = TypeVar("_R")

logger = init_logger(__name__)

# The deprecated ``pynvml`` distribution is imported from a package
# directory (``pynvml/__init__.py``), whereas ``nvidia-ml-py`` presumably
# ships a single ``pynvml.py`` module; the imported file name distinguishes
# the two.
if pynvml.__file__.endswith("__init__.py"):
    logger.warning(
        "You are using a deprecated `pynvml` package. "
        "Please install `nvidia-ml-py` instead, and make sure to uninstall "
        "`pynvml`. When both of them are installed, `pynvml` will take "
        "precedence and cause errors. See https://pypi.org/project/pynvml "
        "for more information.")

29
30
31
32
33
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.

34

35
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that every call runs inside an NVML session.

    ``pynvml.nvmlInit()`` is called before ``fn`` runs, and
    ``pynvml.nvmlShutdown()`` is called afterwards whether ``fn`` returns
    or raises. If ``nvmlInit`` itself raises, ``fn`` is not called and no
    shutdown is attempted. ``functools.wraps`` copies ``fn``'s metadata
    (name, docstring, ``__wrapped__``) onto the returned wrapper.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            # Pair the nvmlInit above on both the success and error paths.
            pynvml.nvmlShutdown()
        return outcome

    return _nvml_scoped


48
49
50
51
52
53
54
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the CUDA compute capability of a physical GPU.

    ``device_id`` is an NVML index, which ``CUDA_VISIBLE_DEVICES`` does not
    remap. The value is whatever ``nvmlDeviceGetCudaComputeCapability``
    reports (callers unpack it as ``major, minor``).

    ``lru_cache`` is the outermost decorator, so a cache hit returns
    immediately without opening an NVML session. Up to 8 device ids are
    memoized.
    """
    return pynvml.nvmlDeviceGetCudaComputeCapability(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))


55
56
57
58
59
60
61
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_name(device_id: int = 0) -> str:
    """Return the name NVML reports for a physical GPU.

    ``device_id`` is an NVML index, which ``CUDA_VISIBLE_DEVICES`` does not
    remap. Results are memoized per ``device_id`` (up to 8 entries), and a
    cache hit skips the NVML session entirely.

    NOTE(review): some older ``nvidia-ml-py`` releases return ``bytes`` from
    ``nvmlDeviceGetName``; the ``str`` annotation assumes a recent one.
    """
    device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    reported_name = pynvml.nvmlDeviceGetName(device_handle)
    return reported_name


62
63
64
65
66
67
68
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_total_memory(device_id: int = 0) -> int:
    """Return the ``total`` field of NVML's memory info for a physical GPU.

    ``device_id`` is an NVML index, which ``CUDA_VISIBLE_DEVICES`` does not
    remap. The value is converted to a plain ``int``. Results are memoized
    per ``device_id`` (up to 8 entries).
    """
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))
    return int(mem_info.total)


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@with_nvml_context
def warn_if_different_devices():
    """Warn when the GPUs in the machine are not all the same model.

    Runs inside one NVML session. A warning listing every device name is
    logged when all of the following hold:
    - more than one GPU is present;
    - their NVML names differ;
    - ``CUDA_DEVICE_ORDER`` is not ``"PCI_BUS_ID"``.
    """
    num_devices: int = pynvml.nvmlDeviceGetCount()
    if num_devices <= 1:
        return
    names = [get_physical_device_name(idx) for idx in range(num_devices)]
    mixed_models = len(set(names)) > 1
    # CUDA's default enumeration order (FASTEST_FIRST) may disagree with
    # NVML's index order on heterogeneous systems; PCI_BUS_ID aligns them.
    pci_ordered = os.environ.get("CUDA_DEVICE_ORDER") == "PCI_BUS_ID"
    if mixed_models and not pci_ordered:
        logger.warning(
            "Detected different devices in the system: \n%s\nPlease"
            " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
            "avoid unexpected behavior.", "\n".join(names))


# Run the mixed-device check at import time, except when ``pynvml`` is a
# Sphinx autodoc mock (documentation builds), so NVML is never queried
# through a mock object.
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ImportError:
    # Covers both "Sphinx not installed" (ModuleNotFoundError, a subclass)
    # and "this Sphinx version has no _MockModule" (plain ImportError);
    # in either case pynvml cannot be a Sphinx mock.
    warn_if_different_devices()
else:
    if not isinstance(pynvml, _MockModule):
        warn_if_different_devices()


91
92
93
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into a physical NVML index.

    If ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes its
    comma-separated list, and the chosen entry is parsed with ``int``.
    Otherwise ``device_id`` is returned unchanged.

    Raises:
        RuntimeError: ``CUDA_VISIBLE_DEVICES`` is set to the empty string.
        IndexError: ``device_id`` is past the end of the visible list.
        ValueError: the selected entry is not an integer literal.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # No remapping in effect: logical and physical ids coincide.
        return device_id
    if not visible_devices:
        raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string,"
                           " which means GPU support is disabled.")
    return int(visible_devices.split(",")[device_id])
101
102


103
104
105
class CudaPlatform(Platform):
    """CUDA platform whose device queries are answered through NVML.

    The ``get_device_*`` methods take logical device ids and translate them
    to physical NVML indices via ``device_id_to_physical_device_id``
    (i.e. through ``CUDA_VISIBLE_DEVICES``) before querying.
    """

    _enum = PlatformEnum.CUDA

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of logical device ``device_id``."""
        major, minor = get_physical_device_capability(
            device_id_to_physical_device_id(device_id))
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML-reported name of logical device ``device_id``."""
        return get_physical_device_name(
            device_id_to_physical_device_id(device_id))

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return NVML's total-memory figure for logical ``device_id``."""
        return get_physical_device_total_memory(
            device_id_to_physical_device_id(device_id))

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Check whether every pair of the given GPUs is directly (1 hop)
        connected by NVLink.

        ``physical_device_ids`` are NVML indices, not remapped through
        ``CUDA_VISIBLE_DEVICES``. Each unordered pair is queried once with
        ``nvmlDeviceGetP2PStatus(..., NVML_P2P_CAPS_INDEX_NVLINK)``.

        Returns ``False`` at the first pair whose status is not
        ``NVML_P2P_STATUS_OK``, or whose query raises ``NVMLError``; the
        error is logged with its traceback. Returns ``True`` otherwise,
        including for zero or one device.
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(idx)
            for idx in physical_device_ids
        ]
        for pos, handle in enumerate(handles):
            # Only later peers: each unordered pair is checked exactly once.
            for peer_handle in handles[pos + 1:]:
                try:
                    status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped.")
                    return False
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True