# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""

import ctypes
from dataclasses import dataclass
10
from typing import Any
11
12
13
14

# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch  # noqa

15
import vllm.envs as envs
16
from vllm.logger import init_logger
17
from vllm.platforms import current_platform
18
from vllm.utils.system_utils import find_loaded_library
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

# Module-level logger named after this module.
logger = init_logger(__name__)

# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html

# Both cudart types below are represented as plain C ints on the Python side;
# they are used in the `restype` / `argtypes` declarations of the bindings.
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
    """ctypes layout for an IPC memory handle.

    The handle is treated as an opaque blob: a single field, ``internal``,
    holding 128 raw bytes.
    """

    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]


@dataclass
class Function:
    """Declarative description of one C function to bind through ctypes."""

    # Symbol name used to look the function up in the loaded library.
    name: str
    # Value assigned to the bound function's ``restype``.
    restype: Any
    # Value assigned to the bound function's ``argtypes``.
    argtypes: list[Any]
39
40
41
42
43
44
45
46
47
48
49
50
51


class CudaRTLibrary:
    """Pure-Python (ctypes) binding for a handful of cudart functions.

    On ROCm the same Python-facing names are bound to their HIP equivalents
    (see ``cuda_to_hip_mapping``). Loaded libraries and their bound function
    tables are cached at class level, keyed by library path, so creating
    several instances for the same path does not reload or rebind anything.

    Every wrapper method that returns a status code passes it through
    ``CUDART_CHECK``, which raises ``RuntimeError`` on a non-zero result.
    """

    # Signatures of the functions bound by this wrapper. Each entry supplies
    # the ``restype`` / ``argtypes`` assigned to the ctypes function object.
    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),
        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function(
            "cudaMalloc",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
        ),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function(
            "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
        ),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function(
            "cudaMemcpy",
            cudaError_t,
            [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
        ),
        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function(
            "cudaIpcGetMemHandle",
            cudaError_t,
            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
        ),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function(
            "cudaIpcOpenMemHandle",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
        ),
    ]

    # https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
    cuda_to_hip_mapping = {
        "cudaSetDevice": "hipSetDevice",
        "cudaDeviceSynchronize": "hipDeviceSynchronize",
        "cudaDeviceReset": "hipDeviceReset",
        "cudaGetErrorString": "hipGetErrorString",
        "cudaMalloc": "hipMalloc",
        "cudaFree": "hipFree",
        "cudaMemset": "hipMemset",
        "cudaMemcpy": "hipMemcpy",
        "cudaIpcGetMemHandle": "hipIpcGetMemHandle",
        "cudaIpcOpenMemHandle": "hipIpcOpenMemHandle",
    }

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary of bound functions
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: str | None = None):
        """Load the runtime library and bind the exported functions.

        Args:
            so_file: Path of the shared library to load. When ``None`` the
                path is resolved, in order, from
                ``find_loaded_library("libcudart")``,
                ``find_loaded_library("libamdhip64")`` and
                ``envs.VLLM_CUDART_SO_PATH``.

        Raises:
            AssertionError: If ``so_file`` was not given and none of the
                lookups above produced a path.
        """
        if so_file is None:
            so_file = find_loaded_library("libcudart")
            if so_file is None:
                # libcudart is not loaded in the current process, try hip
                so_file = find_loaded_library("libamdhip64")
                if so_file is None:
                    # neither runtime is loaded; the check below errors out
                    # if the env var does not provide a path either
                    so_file = envs.VLLM_CUDART_SO_PATH  # fallback to env var
            if so_file is None:
                # Raised explicitly (same exception type as the former
                # `assert`) so the check is not stripped under `python -O`.
                raise AssertionError(
                    "libcudart is not loaded in the current process, "
                    "try setting VLLM_CUDART_SO_PATH"
                )

        if so_file not in CudaRTLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            # Query the platform once instead of once per bound function.
            use_hip = current_platform.is_rocm()
            _funcs: dict[str, Any] = {}
            for func in CudaRTLibrary.exported_functions:
                symbol = (
                    CudaRTLibrary.cuda_to_hip_mapping[func.name]
                    if use_hip
                    else func.name
                )
                f = getattr(self.lib, symbol)
                f.restype = func.restype
                f.argtypes = func.argtypes
                # Always keyed by the CUDA name, even when bound to HIP.
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        """Raise ``RuntimeError`` if ``result`` is a non-zero status code."""
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        """Return the runtime's message for ``error``, decoded as UTF-8."""
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        """Call ``cudaSetDevice(device)`` and check the returned status."""
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        """Call ``cudaDeviceSynchronize()`` and check the returned status."""
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        """Call ``cudaDeviceReset()`` and check the returned status."""
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        """Call ``cudaMalloc`` for ``size`` bytes and return the pointer.

        The pointer is received through an out-parameter and returned as a
        ``ctypes.c_void_p``.
        """
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        """Call ``cudaFree(devPtr)`` and check the returned status."""
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
        """Call ``cudaMemset(devPtr, value, count)`` and check the status."""
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(
        self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
    ) -> None:
        """Call ``cudaMemcpy(dst, src, count, kind)`` and check the status.

        ``kind`` is always the value named ``cudaMemcpyDefault`` (4); callers
        cannot choose a different copy kind through this wrapper.
        """
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        """Call ``cudaIpcGetMemHandle`` for ``devPtr`` and return the handle.

        The handle is filled in through an out-parameter.
        """
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(
            self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
        )
        return handle

    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        """Call ``cudaIpcOpenMemHandle`` for ``handle`` and return the pointer.

        The flags argument is always the value named
        ``cudaIpcMemLazyEnablePeerAccess`` (1). The resulting pointer is
        received through an out-parameter.
        """
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(
            self.funcs["cudaIpcOpenMemHandle"](
                ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
            )
        )
        return devPtr