cuda_utils.py 5.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""CUDA driver helpers shared by the GMS client and server."""

from __future__ import annotations

import os

10
from gpu_memory_service.common.locks import GrantedLockType
11
12
from gpu_memory_service.common.utils import fail

13
14
15
16
17
18
19
20
21
22
23
24
25
try:
    from cuda.bindings import driver as cuda
except ImportError:

    class _MissingCuda:
        """Placeholder bound to ``cuda`` when cuda-python cannot be imported.

        Keeps import-time collection working in CPU-only environments and lets
        the first real CUDA call fail with a targeted message instead.
        """

        def __getattr__(self, name):
            # Protocol probes (hasattr(obj, "__wrapped__") in inspect.unwrap,
            # copy, pickle, ...) look up dunder names and only treat
            # AttributeError as "not present". Any other exception escapes and
            # breaks the very collection this placeholder is meant to protect.
            if name.startswith("__") and name.endswith("__"):
                raise AttributeError(name)
            raise RuntimeError(
                "cuda-python is required for GPU Memory Service CUDA operations"
            )

    cuda = _MissingCuda()

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167

def cuda_check_result(result: cuda.CUresult, name: str) -> None:
    """Report a failed CUDA driver call through ``fail``.

    Does nothing when ``result`` is ``CUDA_SUCCESS``. Otherwise the error text
    is looked up with ``cuGetErrorString``; when that lookup does not succeed
    or yields nothing, ``str(result)`` is used as the description instead.

    Args:
        result: Status code returned by the driver call
        name: Name of the driver call, included in the failure message
    """
    success = cuda.CUresult.CUDA_SUCCESS
    if result == success:
        return

    lookup_status, text = cuda.cuGetErrorString(result)
    if lookup_status == success and text:
        # The binding may hand back bytes; normalise to str either way.
        if isinstance(text, bytes):
            description = text.decode()
        else:
            description = str(text)
    else:
        description = str(result)
    fail("fatal CUDA VMM error in %s: %s", name, description)


def cuda_ensure_initialized() -> None:
    """Call ``cuInit(0)`` and check its status with ``cuda_check_result``."""
    [status] = cuda.cuInit(0)
    cuda_check_result(status, "cuInit")


def cumem_get_allocation_granularity(device: int) -> int:
    """Query the minimum VMM allocation granularity for a device.

    The query describes a pinned, device-local allocation that requests a
    POSIX file descriptor handle type.

    Args:
        device: CUDA device index

    Returns:
        Allocation granularity in bytes (typically 2 MiB)
    """
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )

    granularity_flags = cuda.CUmemAllocationGranularity_flags
    status, nbytes = cuda.cuMemGetAllocationGranularity(
        alloc_prop, granularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM
    )
    cuda_check_result(status, "cuMemGetAllocationGranularity")
    return int(nbytes)


def cumem_create_tolerate_oom(size: int, device: int) -> tuple[bool, int]:
    """Create a VMM allocation, treating out-of-memory as a normal outcome.

    The allocation is pinned, device-local, and requests a POSIX file
    descriptor handle type.

    Args:
        size: Size passed to ``cuMemCreate``
        device: CUDA device index

    Returns:
        ``(True, handle)`` on success; ``(False, 0)`` when the driver reports
        ``CUDA_ERROR_OUT_OF_MEMORY``. Any other status is handed to
        ``cuda_check_result``.
    """
    alloc_prop = cuda.CUmemAllocationProp()
    alloc_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
    alloc_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    alloc_prop.location.id = device
    alloc_prop.requestedHandleTypes = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )

    status, new_handle = cuda.cuMemCreate(size, alloc_prop, 0)
    if status == cuda.CUresult.CUDA_SUCCESS:
        return True, int(new_handle)
    if status != cuda.CUresult.CUDA_ERROR_OUT_OF_MEMORY:
        # Anything other than OOM is unexpected and goes through the checker.
        cuda_check_result(status, "cuMemCreate")
    # Reached on OOM, or if the checker returned instead of aborting.
    return False, 0


def cumem_export_to_shareable_handle(handle: int) -> int:
    """Export an allocation handle as a POSIX file descriptor.

    Args:
        handle: Allocation handle to export

    Returns:
        The exported file descriptor as an int
    """
    fd_handle_type = (
        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
    )
    status, exported = cuda.cuMemExportToShareableHandle(handle, fd_handle_type, 0)
    cuda_check_result(status, "cuMemExportToShareableHandle")
    return int(exported)


def align_to_granularity(size: int, granularity: int) -> int:
    """Align size up to VMM granularity.

    Args:
        size: Size in bytes
        granularity: Allocation granularity in bytes; must be positive

    Returns:
        Smallest multiple of ``granularity`` that is >= ``size``

    Raises:
        ValueError: If ``granularity`` is not positive
    """
    # A zero granularity would surface as an opaque ZeroDivisionError and a
    # negative one would silently produce a wrong "aligned" size.
    if granularity <= 0:
        raise ValueError(f"granularity must be positive, got {granularity}")
    return ((size + granularity - 1) // granularity) * granularity


def cumem_import_from_shareable_handle_close_fd(fd: int) -> int:
    """Import an allocation from a POSIX file descriptor, then close the fd.

    The descriptor is closed on every path, including when the import or the
    result check raises.

    Args:
        fd: File descriptor to import from; closed before this function returns

    Returns:
        The imported allocation handle as an int
    """
    try:
        status, imported = cuda.cuMemImportFromShareableHandle(
            fd,
            cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
        )
        cuda_check_result(status, "cuMemImportFromShareableHandle")
        allocation = int(imported)
    finally:
        os.close(fd)
    return allocation


def cumem_address_reserve(size: int, granularity: int) -> int:
    """Reserve a virtual address range with ``cuMemAddressReserve``.

    Args:
        size: Size of the range to reserve
        granularity: Passed as the second argument of ``cuMemAddressReserve``

    Returns:
        The reserved virtual address as an int
    """
    status, reserved = cuda.cuMemAddressReserve(size, granularity, 0, 0)
    cuda_check_result(status, "cuMemAddressReserve")
    return int(reserved)


def cumem_address_free(va: int, size: int) -> None:
    """Free a virtual address range with ``cuMemAddressFree``.

    Args:
        va: Start of the virtual address range
        size: Size of the range
    """
    [status] = cuda.cuMemAddressFree(va, size)
    cuda_check_result(status, "cuMemAddressFree")


def cumem_map(va: int, size: int, handle: int) -> None:
    """Map an allocation handle into a virtual address range with ``cuMemMap``.

    Args:
        va: Start of the virtual address range to map into
        size: Size of the mapping
        handle: Allocation handle to map
    """
    [status] = cuda.cuMemMap(va, size, 0, handle, 0)
    cuda_check_result(status, "cuMemMap")


def cumem_set_access(va: int, size: int, device: int, access: GrantedLockType) -> None:
    """Set a device's access protection on a virtual address range.

    Args:
        va: Start of the virtual address range
        size: Size of the range
        device: CUDA device index the access applies to
        access: ``GrantedLockType.RO`` selects read-only protection; any other
            value selects read-write
    """
    descriptor = cuda.CUmemAccessDesc()
    descriptor.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
    descriptor.location.id = device
    if access == GrantedLockType.RO:
        descriptor.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READ
    else:
        descriptor.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

    [status] = cuda.cuMemSetAccess(va, size, [descriptor], 1)
    cuda_check_result(status, "cuMemSetAccess")


def cumem_unmap(va: int, size: int) -> None:
    """Unmap a virtual address range with ``cuMemUnmap``.

    Args:
        va: Start of the virtual address range
        size: Size of the range
    """
    [status] = cuda.cuMemUnmap(va, size)
    cuda_check_result(status, "cuMemUnmap")


def cumem_release(handle: int) -> None:
    """Release an allocation handle with ``cuMemRelease``.

    Args:
        handle: Allocation handle to release
    """
    [status] = cuda.cuMemRelease(handle)
    cuda_check_result(status, "cuMemRelease")


def cuda_validate_pointer(va: int) -> None:
    """Check ``va`` by querying its device-pointer attribute.

    Only the status of ``cuPointerGetAttribute`` is checked; the attribute
    value itself is discarded.

    Args:
        va: Address to query
    """
    attribute = cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER
    status, _unused = cuda.cuPointerGetAttribute(attribute, va)
    cuda_check_result(status, "cuPointerGetAttribute")


def cuda_synchronize() -> None:
    """Call ``cuCtxSynchronize`` and check its status with ``cuda_check_result``."""
    [status] = cuda.cuCtxSynchronize()
    cuda_check_result(status, "cuCtxSynchronize")