"tests/test_vllm_port.py" did not exist on "b18201fe060a3ddcc088f8aea3cf1d7c4b461288"
cuda.py 3.92 KB
Newer Older
1
2
3
4
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

5
import os
6
from functools import lru_cache, wraps
7
from typing import Callable, List, Tuple, TypeVar
8
9

import pynvml
10
from typing_extensions import ParamSpec
11

12
13
from vllm.logger import init_logger

14
15
from .interface import Platform, PlatformEnum

16
17
logger = init_logger(__name__)

# Type variables used by ``with_nvml_context`` to preserve the wrapped
# function's parameter specification and return type.
_P = ParamSpec("_P")
_R = TypeVar("_R")

21
22
23
24
25
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.

26

27
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that each call runs inside an NVML session.

    The returned wrapper calls ``pynvml.nvmlInit()`` before invoking ``fn``
    and ``pynvml.nvmlShutdown()`` afterwards, including when ``fn`` raises.
    ``functools.wraps`` keeps ``fn``'s name and docstring on the wrapper.

    Decorated functions may call each other (e.g. ``warn_if_different_devices``
    calls ``get_physical_device_name``). This relies on NVML reference-counting
    init/shutdown pairs, per the NVML API reference.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # nvmlInit() stays outside the try: if it fails, nothing was
        # initialized, so no matching shutdown must be issued.
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


40
41
42
43
44
45
46
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Look up the CUDA compute capability of an NVML (physical) device.

    ``device_id`` is a physical NVML index, so ``CUDA_VISIBLE_DEVICES`` does
    not apply here. The value is whatever
    ``nvmlDeviceGetCudaComputeCapability`` reports (presumably
    ``(major, minor)``).

    ``lru_cache`` is the outermost decorator, so an NVML session is opened
    only on a cache miss. Up to 8 device ids are memoized.
    """
    return pynvml.nvmlDeviceGetCudaComputeCapability(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_name(device_id: int = 0) -> str:
    """Return the NVML-reported product name of physical device ``device_id``.

    ``device_id`` is an NVML (physical) index, not a CUDA ordinal, so it is
    not affected by ``CUDA_VISIBLE_DEVICES``. Results are memoized per
    ``device_id`` (up to 8 entries). NVML is initialized only on a cache miss,
    because ``lru_cache`` wraps the NVML-context wrapper.
    """
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    name = pynvml.nvmlDeviceGetName(handle)
    # Some older pynvml releases return the name as ``bytes``, while newer
    # ones already decode it. Normalize so the ``-> str`` contract holds and
    # callers that join or compare names work with either binding version.
    if isinstance(name, bytes):
        name = name.decode()
    return name


@with_nvml_context
def warn_if_different_devices():
    """Warn when the node mixes GPU models and CUDA is not in PCI bus order.

    A warning listing every device name (in NVML index order) is logged only
    when all of the following hold:
    - NVML reports more than one device;
    - at least two device names differ;
    - ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``.

    Always returns ``None``.
    """
    num_devices: int = pynvml.nvmlDeviceGetCount()
    if num_devices <= 1:
        return
    names = [get_physical_device_name(index) for index in range(num_devices)]
    if len(set(names)) <= 1:
        # Every device reports the same name: nothing to warn about.
        return
    # NOTE(review): presumably this matters because CUDA's default device
    # order can differ from NVML's physical indexing on mixed-model nodes,
    # and PCI_BUS_ID order aligns them; confirm against the CUDA docs.
    if os.environ.get("CUDA_DEVICE_ORDER") == "PCI_BUS_ID":
        return
    logger.warning(
        "Detected different devices in the system: \n%s\nPlease"
        " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
        "avoid unexpected behavior.", "\n".join(names))


# Run the mixed-device check once at import time, unless pynvml has been
# replaced by a Sphinx autodoc mock (e.g. while building documentation).
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ImportError:
    # Sphinx is not installed, or no longer exposes the private
    # ``_MockModule``. A mock cannot be detected either way, so run the check.
    # ImportError also covers ModuleNotFoundError.
    warn_if_different_devices()
else:
    # Kept out of the ``try`` so an ImportError raised by the check itself
    # is not mistaken for "Sphinx missing" (which would re-run the check).
    if not isinstance(pynvml, _MockModule):
        warn_if_different_devices()


76
77
78
79
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a CUDA device ordinal into an NVML (physical) device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated list and the selected entry is returned as an ``int``.
    Otherwise ``device_id`` is returned unchanged.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set but empty (or only
            whitespace), i.e. no GPU is visible to this process.
        IndexError: if ``device_id`` is past the end of the visible list.
        ValueError: if the selected entry is not an integer (e.g. a GPU
            UUID), which this helper does not resolve.
    """
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        return device_id
    visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
    if not visible_devices.strip():
        # Without this check, int("") would surface as an opaque ValueError.
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, which means no "
            "GPU is visible to this process. Unset it or list the devices "
            "to use.")
    # NOTE(review): a negative device_id indexes from the end of the list
    # (Python semantics) rather than failing; confirm callers never pass one.
    return int(visible_devices.split(",")[device_id])
83
84


85
86
87
88
89
class CudaPlatform(Platform):
    """CUDA platform whose device queries go through NVML.

    Per-device queries take CUDA ordinals as seen by this process and map
    them to NVML physical indices with ``device_id_to_physical_device_id``.
    This honors ``CUDA_VISIBLE_DEVICES`` without touching the CUDA runtime.
    """
    _enum = PlatformEnum.CUDA

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the compute capability of CUDA device ``device_id``.

        ``device_id`` is a logical CUDA ordinal. It is translated to a
        physical NVML index before the (cached) NVML lookup.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_capability(physical_device_id)

    @staticmethod
    def get_device_name(device_id: int = 0) -> str:
        """Return the NVML product name of CUDA device ``device_id``.

        ``device_id`` is a logical CUDA ordinal, translated to a physical
        NVML index first.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_name(physical_device_id)

    @staticmethod
    @with_nvml_context
    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
        """Query whether every pair of the given GPUs is connected by NVLink.

        ``physical_device_ids`` are NVML (physical) indices, not CUDA
        ordinals. Each unordered pair is checked for direct (1 hop) NVLink
        P2P status. An empty or single-element list returns ``True``.
        An NVML error from the P2P query is logged and treated as "not fully
        connected" (``False``).
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for i, handle in enumerate(handles):
            # Visit each unordered pair once: only peers after ``handle``.
            for peer_handle in handles[i + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                except pynvml.NVMLError as error:
                    logger.error(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped.",
                        exc_info=error)
                    return False
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True