allocations.py 6.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Server-side CUDA allocation store."""

from __future__ import annotations

import asyncio
import logging
10
import os
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import time
from dataclasses import dataclass
from typing import Callable, Optional
from uuid import uuid4

from gpu_memory_service.common.cuda_utils import (
    align_to_granularity,
    cuda_ensure_initialized,
    cumem_create_tolerate_oom,
    cumem_export_to_shareable_handle,
    cumem_get_allocation_granularity,
    cumem_release,
)

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class AllocationInfo:
    allocation_id: str
    size: int
    aligned_size: int
    handle: int
34
    export_fd: int
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
    tag: str
    layout_slot: int
    created_at: float


class AllocationNotFoundError(Exception):
    """Raised when an allocation_id doesn't exist."""


class GMSAllocationManager:
    """Server-side CUDA VMM allocation store."""

    def __init__(
        self,
        device: int = 0,
        *,
        allocation_retry_interval: float = 0.5,
        allocation_retry_timeout: Optional[float] = None,
    ):
        if allocation_retry_interval <= 0:
            raise ValueError(
                f"allocation_retry_interval must be > 0, got {allocation_retry_interval}"
            )
        if allocation_retry_timeout is not None and allocation_retry_timeout <= 0:
            raise ValueError(
                f"allocation_retry_timeout must be > 0 when set, got {allocation_retry_timeout}"
            )

        self._device = device
        self._allocations: dict[str, AllocationInfo] = {}
        self._next_layout_slot = 0
        cuda_ensure_initialized()
        self._granularity = cumem_get_allocation_granularity(device)
        self._allocation_retry_interval = allocation_retry_interval
        self._allocation_retry_timeout = allocation_retry_timeout
        logger.info(
            "GMSAllocationManager initialized: device=%d, granularity=%d, "
            "alloc_retry_interval=%.3f, alloc_retry_timeout=%s",
            device,
            self._granularity,
            self._allocation_retry_interval,
            (
                f"{self._allocation_retry_timeout:.3f}"
                if self._allocation_retry_timeout is not None
                else "none"
            ),
        )

    @property
    def device(self) -> int:
        return self._device

    @property
    def allocation_count(self) -> int:
        return len(self._allocations)

    async def allocate(
        self,
        size: int,
        tag: str = "default",
        is_connected: Optional[Callable[[], bool]] = None,
        on_oom: Optional[Callable[[], None]] = None,
    ) -> AllocationInfo:
        if size <= 0:
            raise ValueError(f"size must be > 0, got {size}")

        aligned_size = align_to_granularity(size, self._granularity)
        started_at = time.monotonic()
        reported_oom = False
        while True:
            if is_connected is not None and not is_connected():
                raise ConnectionAbortedError(
                    "RW client disconnected during allocation retry"
                )

            allocated, handle = cumem_create_tolerate_oom(aligned_size, self._device)
            if allocated:
                break

            if on_oom is not None and not reported_oom:
                on_oom()
                reported_oom = True

            if self._allocation_retry_timeout is not None:
                waited = time.monotonic() - started_at
                if waited >= self._allocation_retry_timeout:
                    raise TimeoutError(
                        "Timed out waiting for GPU memory: "
                        f"requested_size={size}, aligned_size={aligned_size}, "
                        f"tag={tag}, waited_sec={waited:.3f}"
                    )

            logger.warning(
                "cuMemCreate OOM for aligned_size=%d bytes, tag=%s; retrying in %.3fs",
                aligned_size,
                tag,
                self._allocation_retry_interval,
            )
            await asyncio.sleep(self._allocation_retry_interval)

135
        export_fd = int(cumem_export_to_shareable_handle(int(handle)))
136
137
138
139
140
        info = AllocationInfo(
            allocation_id=str(uuid4()),
            size=size,
            aligned_size=aligned_size,
            handle=int(handle),
141
            export_fd=export_fd,
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
            tag=tag,
            layout_slot=self._next_layout_slot,
            created_at=time.time(),
        )
        self._next_layout_slot = info.layout_slot + 1
        self._allocations[info.allocation_id] = info
        logger.debug(
            "Allocated %s: size=%d, aligned=%d, tag=%s, slot=%d",
            info.allocation_id,
            size,
            aligned_size,
            tag,
            info.layout_slot,
        )
        return info

    def export_allocation(self, allocation_id: str) -> int:
159
160
        info = self.get_allocation(allocation_id)
        return os.dup(info.export_fd)
161
162
163
164
165

    def free_allocation(self, allocation_id: str) -> bool:
        info = self._allocations.get(allocation_id)
        if info is None:
            return False
166
        os.close(info.export_fd)
167
168
169
170
171
172
173
174
175
        cumem_release(info.handle)
        self._allocations.pop(allocation_id, None)
        logger.debug("Freed allocation: %s", allocation_id)
        return True

    def clear_all(self) -> int:
        allocation_ids = list(self._allocations)
        for allocation_id in allocation_ids:
            info = self._allocations[allocation_id]
176
            os.close(info.export_fd)
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
            cumem_release(info.handle)
            self._allocations.pop(allocation_id, None)
        if allocation_ids:
            logger.info("Cleared %d allocations", len(allocation_ids))
        self._next_layout_slot = 0
        return len(allocation_ids)

    def get_allocation(self, allocation_id: str) -> AllocationInfo:
        info = self._allocations.get(allocation_id)
        if info is None:
            raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
        return info

    def list_allocations(self, tag: Optional[str] = None) -> list[AllocationInfo]:
        allocations = list(self._allocations.values())
        allocations.sort(key=lambda info: info.layout_slot)
        if tag is None:
            return allocations
        return [info for info in allocations if info.tag == tag]