rpc.py 9.17 KB
Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
"""Async GMS RPC transport server."""
5
6
7
8
9
10

from __future__ import annotations

import asyncio
import logging
import os
11
12
13
import select
import socket
from typing import Optional
14
15
16

from gpu_memory_service.common.protocol.messages import (
    ErrorResponse,
17
18
    GetEventHistoryRequest,
    GetRuntimeStateRequest,
19
20
21
22
    HandshakeRequest,
    HandshakeResponse,
)
from gpu_memory_service.common.protocol.wire import recv_message, send_message
23
from gpu_memory_service.common.utils import fail
24

25
26
27
from .allocations import AllocationNotFoundError
from .gms import GMS
from .session import Connection, InvalidTransition, OperationNotAllowed
28
29
30
31

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def _is_connection_alive(conn: Connection) -> bool:
    """Report whether the peer behind ``conn`` still looks connected.

    The check is non-blocking. It returns ``False`` as soon as any of the
    following holds: the writer is closing, the reader hit EOF or recorded an
    exception, the transport exposes no usable socket descriptor, or a
    zero-timeout poll reports an error / hang-up condition on the descriptor.
    Otherwise it returns ``True``.
    """
    stream_writer = conn.writer
    stream_reader = conn.reader

    # Cheap stream-level checks first; no system call is needed for these.
    if (
        stream_writer.is_closing()
        or stream_reader.at_eof()
        or stream_reader.exception() is not None
    ):
        return False

    raw_socket = stream_writer.get_extra_info("socket")
    if raw_socket is None:
        return False

    try:
        descriptor = raw_socket.fileno()
    except OSError:
        return False
    if descriptor < 0:
        return False

    # POLLRDHUP is not defined on every platform, so it is only OR-ed in
    # when the select module provides it.
    hangup_mask = (
        select.POLLERR
        | select.POLLHUP
        | select.POLLNVAL
        | getattr(select, "POLLRDHUP", 0)
    )
    watcher = select.poll()
    watcher.register(descriptor, hangup_mask)
    # A zero timeout makes this a pure status query; any reported event
    # means the connection is no longer healthy.
    pending = watcher.poll(0)
    return not pending
53

54
55
56

class GMSRPCServer:
    """Unix-socket transport for the GPU Memory Service.

    Accepts client connections on ``socket_path``, performs a handshake for
    each one, and forwards every subsequent request to the wrapped ``GMS``
    instance. Unexpected errors are reported through ``fail``.
    """

    def __init__(
        self,
        socket_path: str,
        device: int = 0,
        *,
        allocation_retry_interval: float = 0.5,
        allocation_retry_timeout: Optional[float] = None,
    ):
        """Create the transport and the ``GMS`` instance it serves.

        Args:
            socket_path: Filesystem path of the Unix socket to listen on.
            device: Device index, forwarded to ``GMS``.
            allocation_retry_interval: Forwarded unchanged to ``GMS``.
            allocation_retry_timeout: Forwarded unchanged to ``GMS``.
        """
        self.socket_path = socket_path
        self.device = device
        self._gms = GMS(
            device,
            allocation_retry_interval=allocation_retry_interval,
            allocation_retry_timeout=allocation_retry_timeout,
        )
        # Set by serve(); None until the server has been started.
        self._server: Optional[asyncio.Server] = None
        logger.info("GMSRPCServer initialized: device=%d", device)

    def _prepare_socket_path(self) -> None:
        """Make ``socket_path`` available for binding.

        Does nothing when the path does not exist. Otherwise a probe
        connection decides whether something is still listening there.

        Raises:
            RuntimeError: If the probe connects, i.e. another server is
                already accepting connections on ``socket_path``.
        """
        if not os.path.exists(self.socket_path):
            return

        # The context manager closes the probe socket on every exit path.
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as probe:
            try:
                probe.connect(self.socket_path)
            except OSError:
                listener_present = False
            else:
                listener_present = True

        if listener_present:
            raise RuntimeError(f"GMS already running at {self.socket_path}")

        # Nothing answered: treat the path as a leftover and remove it.
        # Unlinking directly (instead of checking existence first) stays
        # correct if the file disappears between the probe and the removal.
        try:
            os.unlink(self.socket_path)
        except FileNotFoundError:
            pass

    @property
    def state(self):
        """Current state as reported by the underlying ``GMS`` instance."""
        return self._gms.state

    def is_ready(self) -> bool:
        """Return the readiness reported by the underlying ``GMS`` instance."""
        return self._gms.is_ready()

    async def _handle_connection(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one client connection from handshake to cleanup.

        This is the callback handed to ``asyncio.start_unix_server``. A
        connection reset is logged and ends the session; cancellation is
        propagated; every other exception is reported through ``fail``.
        ``cleanup_connection`` always runs, also when the handshake produced
        no connection (``conn`` is then ``None``).
        """
        conn: Optional[Connection] = None
        session_id = self._gms.next_session_id()
        try:
            conn = await self._do_handshake(reader, writer, session_id)
            if conn is None:
                # Handshake was rejected or was a one-shot query.
                return
            await self._request_loop(conn)
        except (InvalidTransition, AssertionError) as exc:
            # Listed explicitly for readability; handled exactly like any
            # other unexpected exception below.
            fail("fatal server error", exc_info=exc)
        except ConnectionResetError:
            logger.debug("Connection reset: %s", session_id)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            fail("fatal server error", exc_info=exc)
        finally:
            try:
                await self._gms.cleanup_connection(conn)
            except Exception as exc:
                fail("fatal server error", exc_info=exc)

    async def _do_handshake(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        session_id: str,
    ) -> Optional[Connection]:
        """Process the first message of a connection.

        The first message selects one of these outcomes:

        * ``GetRuntimeStateRequest`` / ``GetEventHistoryRequest``: answer
          the query, close the writer, return ``None``.
        * ``HandshakeRequest``: try to acquire the requested lock. On
          success register the connection and return it; otherwise reply
          with ``success=False``, close the writer and return ``None``.
        * anything else: reply with an ``ErrorResponse``, close the writer
          and return ``None``.

        Returns:
            The established ``Connection``, or ``None`` when no session was
            created. On every ``None`` path the writer has been closed
            here, or the connection was handed to ``cleanup_connection``.

        Raises:
            Exception: Whatever ``Connection(...)`` or ``on_connect`` raise,
                after ``cancel_connect`` has been awaited.
        """
        try:
            msg, _, recv_buffer = await recv_message(reader, bytearray())
        except Exception:
            logger.exception("Handshake recv error")
            # Close the writer here like every other rejection path does.
            # No Connection exists yet, so the caller's cleanup has no
            # handle on this transport and it would otherwise stay open.
            writer.close()
            return None

        if isinstance(msg, GetRuntimeStateRequest):
            try:
                await send_message(writer, self._gms.get_runtime_state())
            except Exception as exc:
                logger.debug("Runtime-state response send failed: %s", exc)
            finally:
                writer.close()
            return None

        if isinstance(msg, GetEventHistoryRequest):
            try:
                await send_message(writer, self._gms.get_event_history())
            except Exception as exc:
                logger.debug("Event-history response send failed: %s", exc)
            finally:
                writer.close()
            return None

        if not isinstance(msg, HandshakeRequest):
            try:
                await send_message(
                    writer, ErrorResponse(error="Expected HandshakeRequest")
                )
            except Exception as exc:
                # Best effort: the connection is closed right after anyway.
                logger.debug(
                    "Handshake rejection send failed for session %s: %s",
                    session_id,
                    exc,
                )
            writer.close()
            return None

        granted_mode = await self._gms.acquire_lock(
            msg.lock_type,
            msg.timeout_ms,
            session_id,
        )
        if granted_mode is None:
            # No lock was granted: tell the client and drop the connection.
            try:
                await send_message(
                    writer,
                    HandshakeResponse(success=False, committed=self._gms.committed),
                )
            except Exception as exc:
                # Best effort: the connection is closed right after anyway.
                logger.debug(
                    "Lock-denied response send failed for session %s: %s",
                    session_id,
                    exc,
                )
            writer.close()
            return None

        try:
            conn = Connection(
                reader=reader,
                writer=writer,
                mode=granted_mode,
                session_id=session_id,
                recv_buffer=recv_buffer,
            )
            self._gms.on_connect(conn)
        except Exception:
            # The lock is already held; undo it before propagating.
            await self._gms.cancel_connect(session_id, granted_mode)
            raise

        try:
            await send_message(
                writer,
                HandshakeResponse(
                    success=True,
                    committed=self._gms.committed,
                    granted_lock_type=granted_mode,
                ),
            )
        except Exception as exc:
            logger.warning(
                "Handshake failed after acquiring %s for session %s: %s",
                granted_mode.value,
                session_id,
                exc,
            )
            # The connection was registered via on_connect, so it must be
            # cleaned up before giving up on the session.
            await self._gms.cleanup_connection(conn)
            return None

        return conn

    async def _request_loop(self, conn: Connection) -> None:
        """Receive, dispatch and answer requests until the session ends.

        The loop returns when receiving fails, when the connection is lost
        while a request is handled, when a response cannot be sent, or when
        ``handle_request`` asks for the connection to be closed. Rejected
        requests are answered with an ``ErrorResponse`` and the loop
        continues. Any other exception from ``handle_request`` is reported
        through ``fail``.
        """
        while True:
            try:
                msg, _, conn.recv_buffer = await recv_message(
                    conn.reader, conn.recv_buffer
                )
            except ConnectionResetError:
                return
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                logger.warning("Recv error on session %s: %s", conn.session_id, exc)
                return

            if msg is None:
                # NOTE(review): presumably no complete message was available
                # yet; confirm against recv_message.
                continue

            # -1 marks "no descriptor to pass along / close".
            fd = -1
            try:
                response, fd, should_close = await self._gms.handle_request(
                    conn,
                    msg,
                    lambda: _is_connection_alive(conn),
                )
            except ConnectionAbortedError as exc:
                logger.warning(
                    "Connection lost during %s on session %s: %s",
                    type(msg).__name__,
                    conn.session_id,
                    exc,
                )
                return
            except (
                OperationNotAllowed,
                ValueError,
                TimeoutError,
                AllocationNotFoundError,
            ) as exc:
                # Request-level rejection: report it to the client and keep
                # the session open.
                logger.warning(
                    "Rejected %s from session %s: %s",
                    type(msg).__name__,
                    conn.session_id,
                    exc,
                )
                try:
                    await send_message(conn.writer, ErrorResponse(error=str(exc)))
                except Exception as send_exc:
                    logger.warning(
                        "Failed to send ErrorResponse for %s on session %s: %s",
                        type(msg).__name__,
                        conn.session_id,
                        send_exc,
                    )
                    return
                continue
            except (InvalidTransition, AssertionError) as exc:
                # NOTE(review): the code below assumes fail() does not
                # return (``response`` would be unbound); confirm.
                fail("fatal server error", exc_info=exc)
            except Exception as exc:
                fail("fatal server error", exc_info=exc)

            try:
                await send_message(conn.writer, response, fd)
            except Exception as exc:
                logger.warning(
                    "Response send failed for %s on session %s: %s",
                    type(msg).__name__,
                    conn.session_id,
                    exc,
                )
                return
            finally:
                # Close the local descriptor whether or not the send worked.
                if fd >= 0:
                    os.close(fd)

            if should_close:
                return

    async def serve(self) -> None:
        """Bind the Unix socket and serve connections until cancelled.

        Raises:
            RuntimeError: If another server already listens on
                ``socket_path`` (see ``_prepare_socket_path``).
        """
        self._prepare_socket_path()
        self._server = await asyncio.start_unix_server(
            self._handle_connection,
            path=self.socket_path,
        )
        logger.info("Server started: %s", self.socket_path)
        await self._server.serve_forever()