allocations.py 6.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Server-side CUDA allocation store."""

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Callable, Optional
from uuid import uuid4

from gpu_memory_service.common.cuda_utils import (
    align_to_granularity,
    cuda_ensure_initialized,
    cumem_create_tolerate_oom,
    cumem_export_to_shareable_handle,
    cumem_get_allocation_granularity,
    cumem_release,
)

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class AllocationInfo:
    """Immutable record describing one allocation tracked by the server.

    Instances are created by ``GMSAllocationManager.allocate`` and stored in
    the manager's registry keyed by ``allocation_id``. The dataclass is
    frozen, so fields cannot be reassigned after construction.
    """

    # Registry key; ``allocate`` fills this with ``str(uuid4())``.
    allocation_id: str
    # Size in bytes as requested by the caller (before alignment).
    size: int
    # ``size`` after ``align_to_granularity``; this is the size passed to
    # ``cumem_create_tolerate_oom``.
    aligned_size: int
    # Integer value of the handle returned by ``cumem_create_tolerate_oom``;
    # later passed to ``cumem_release`` / ``cumem_export_to_shareable_handle``.
    handle: int
    # Caller-supplied label; ``list_allocations`` can filter on it.
    tag: str
    # Position taken from the manager's increasing slot counter;
    # ``list_allocations`` sorts by it and ``clear_all`` resets the counter.
    layout_slot: int
    # ``time.time()`` value captured when the record was built.
    created_at: float


class AllocationNotFoundError(Exception):
    """Signals that no allocation is registered under the given allocation_id."""


class GMSAllocationManager:
    """Server-side CUDA VMM allocation store.

    Keeps a registry of ``AllocationInfo`` records keyed by allocation id,
    hands out increasing layout slots, and retries allocation while the
    device reports out-of-memory.
    """

    def __init__(
        self,
        device: int = 0,
        *,
        allocation_retry_interval: float = 0.5,
        allocation_retry_timeout: Optional[float] = None,
    ):
        """Validate the retry settings and query the allocation granularity.

        Args:
            device: Device index passed to the CUDA helper functions.
            allocation_retry_interval: Seconds to wait between attempts after
                an out-of-memory result. Must be > 0.
            allocation_retry_timeout: Upper bound, in seconds, on how long
                ``allocate`` keeps retrying. ``None`` means retry
                indefinitely. Must be > 0 when set.

        Raises:
            ValueError: If either retry setting is not strictly positive.
        """
        if allocation_retry_interval <= 0:
            raise ValueError(
                f"allocation_retry_interval must be > 0, got {allocation_retry_interval}"
            )
        if allocation_retry_timeout is not None and allocation_retry_timeout <= 0:
            raise ValueError(
                f"allocation_retry_timeout must be > 0 when set, got {allocation_retry_timeout}"
            )

        self._device = device
        self._allocations: dict[str, AllocationInfo] = {}
        self._next_layout_slot = 0
        cuda_ensure_initialized()
        self._granularity = cumem_get_allocation_granularity(device)
        self._allocation_retry_interval = allocation_retry_interval
        self._allocation_retry_timeout = allocation_retry_timeout
        logger.info(
            "GMSAllocationManager initialized: device=%d, granularity=%d, "
            "alloc_retry_interval=%.3f, alloc_retry_timeout=%s",
            device,
            self._granularity,
            self._allocation_retry_interval,
            (
                f"{self._allocation_retry_timeout:.3f}"
                if self._allocation_retry_timeout is not None
                else "none"
            ),
        )

    @property
    def device(self) -> int:
        """Device index this manager was constructed with."""
        return self._device

    @property
    def allocation_count(self) -> int:
        """Number of allocations currently held in the registry."""
        return len(self._allocations)

    def _retry_delay(self, waited: float) -> float:
        """Return how long to sleep before the next allocation attempt.

        The delay is the configured retry interval, shortened to the time
        left before ``allocation_retry_timeout`` so that the final attempt
        happens at the deadline instead of up to one interval past it.

        Args:
            waited: Seconds elapsed since the allocation request started.

        Returns:
            A non-negative delay in seconds.
        """
        interval = self._allocation_retry_interval
        timeout = self._allocation_retry_timeout
        if timeout is None:
            return interval
        return min(interval, max(timeout - waited, 0.0))

    async def allocate(
        self,
        size: int,
        tag: str = "default",
        is_connected: Optional[Callable[[], bool]] = None,
        on_oom: Optional[Callable[[], None]] = None,
    ) -> AllocationInfo:
        """Create an allocation, retrying while the device is out of memory.

        Args:
            size: Requested size in bytes; must be > 0. It is rounded with
                ``align_to_granularity`` before the allocation is attempted.
            tag: Label stored on the resulting record.
            is_connected: Optional probe checked before every attempt; when
                it returns a falsy value the request is abandoned.
            on_oom: Optional callback invoked once per call, on the first
                out-of-memory result.

        Returns:
            The ``AllocationInfo`` record that was added to the registry.

        Raises:
            ValueError: If ``size`` is not positive.
            ConnectionAbortedError: If ``is_connected`` reports a disconnect.
            TimeoutError: If a retry timeout is configured and it elapses
                without a successful allocation.
        """
        if size <= 0:
            raise ValueError(f"size must be > 0, got {size}")

        aligned_size = align_to_granularity(size, self._granularity)
        started_at = time.monotonic()
        reported_oom = False
        while True:
            if is_connected is not None and not is_connected():
                raise ConnectionAbortedError(
                    "RW client disconnected during allocation retry"
                )

            allocated, handle = cumem_create_tolerate_oom(aligned_size, self._device)
            if allocated:
                break

            # Notify the caller about memory pressure only once per request.
            if on_oom is not None and not reported_oom:
                on_oom()
                reported_oom = True

            waited = time.monotonic() - started_at
            timeout = self._allocation_retry_timeout
            if timeout is not None and waited >= timeout:
                raise TimeoutError(
                    "Timed out waiting for GPU memory: "
                    f"requested_size={size}, aligned_size={aligned_size}, "
                    f"tag={tag}, waited_sec={waited:.3f}"
                )

            # Never sleep past the deadline: a full interval here could delay
            # the TimeoutError by almost one whole retry interval.
            delay = self._retry_delay(waited)
            logger.warning(
                "cuMemCreate OOM for aligned_size=%d bytes, tag=%s; retrying in %.3fs",
                aligned_size,
                tag,
                delay,
            )
            await asyncio.sleep(delay)

        info = AllocationInfo(
            allocation_id=str(uuid4()),
            size=size,
            aligned_size=aligned_size,
            handle=int(handle),
            tag=tag,
            layout_slot=self._next_layout_slot,
            created_at=time.time(),
        )
        self._next_layout_slot = info.layout_slot + 1
        self._allocations[info.allocation_id] = info
        logger.debug(
            "Allocated %s: size=%d, aligned=%d, tag=%s, slot=%d",
            info.allocation_id,
            size,
            aligned_size,
            tag,
            info.layout_slot,
        )
        return info

    def export_allocation(self, allocation_id: str) -> int:
        """Export the allocation's handle via ``cumem_export_to_shareable_handle``.

        Raises:
            AllocationNotFoundError: If ``allocation_id`` is not registered.
        """
        info = self.get_allocation(allocation_id)
        return cumem_export_to_shareable_handle(info.handle)

    def free_allocation(self, allocation_id: str) -> bool:
        """Release one allocation and drop it from the registry.

        Returns:
            ``True`` if the allocation existed and was released, ``False``
            if ``allocation_id`` is unknown.
        """
        info = self._allocations.get(allocation_id)
        if info is None:
            return False
        # Release first so the record stays registered if the release raises.
        cumem_release(info.handle)
        self._allocations.pop(allocation_id, None)
        logger.debug("Freed allocation: %s", allocation_id)
        return True

    def clear_all(self) -> int:
        """Release every registered allocation and reset the slot counter.

        Returns:
            The number of allocations that were released.
        """
        # Snapshot the keys: the registry is mutated while we walk it.
        allocation_ids = list(self._allocations)
        for allocation_id in allocation_ids:
            cumem_release(self._allocations[allocation_id].handle)
            self._allocations.pop(allocation_id, None)
        if allocation_ids:
            logger.info("Cleared %d allocations", len(allocation_ids))
        self._next_layout_slot = 0
        return len(allocation_ids)

    def get_allocation(self, allocation_id: str) -> AllocationInfo:
        """Look up the record registered under ``allocation_id``.

        Raises:
            AllocationNotFoundError: If no such allocation is registered.
        """
        info = self._allocations.get(allocation_id)
        if info is None:
            raise AllocationNotFoundError(f"Unknown allocation: {allocation_id}")
        return info

    def list_allocations(self, tag: Optional[str] = None) -> list[AllocationInfo]:
        """Return registered allocations ordered by ``layout_slot``.

        Args:
            tag: When given, only allocations whose ``tag`` equals it are
                returned; ``None`` returns all of them.
        """
        selected = (
            info
            for info in self._allocations.values()
            if tag is None or info.tag == tag
        )
        return sorted(selected, key=lambda info: info.layout_slot)