# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Async GMS RPC transport server."""

from __future__ import annotations

import asyncio
import logging
import os
import select
import socket
import stat
from typing import Optional

from gpu_memory_service.common.protocol.messages import (
    ErrorResponse,
    GetEventHistoryRequest,
    GetRuntimeStateRequest,
    HandshakeRequest,
    HandshakeResponse,
)
from gpu_memory_service.common.protocol.wire import recv_message, send_message
from gpu_memory_service.common.utils import fail

from .allocations import AllocationNotFoundError
from .fsm import Connection, InvalidTransition
from .gms import GMS
from .session import OperationNotAllowed

logger = logging.getLogger(__name__)


def _is_connection_alive(conn: Connection) -> bool:
    """Report whether the peer behind *conn* still appears to be connected.

    Runs cheap, non-blocking checks in this order:

    * the writer is not already closing;
    * the reader has neither reached EOF nor recorded an exception;
    * the transport exposes a socket with a valid (non-negative) descriptor;
    * a zero-timeout ``poll`` on that descriptor reports none of
      ``POLLERR`` / ``POLLHUP`` / ``POLLNVAL`` (and ``POLLRDHUP`` on
      platforms that define it).

    Returns:
        ``True`` only if every check passes, ``False`` as soon as one fails.
    """
    reader, writer = conn.reader, conn.writer

    # Stream-level signals first; ``or`` short-circuits in the same order
    # as the individual checks would.
    if (
        writer.is_closing()
        or reader.at_eof()
        or reader.exception() is not None
    ):
        return False

    raw_sock = writer.get_extra_info("socket")
    if raw_sock is None:
        return False
    try:
        descriptor = raw_sock.fileno()
    except OSError:
        return False
    if descriptor < 0:
        # A closed socket reports -1 rather than raising.
        return False

    # POLLRDHUP (peer shut down its write side) is Linux-specific, so only
    # include it where the select module provides it.
    hangup_mask = select.POLLERR | select.POLLHUP | select.POLLNVAL
    hangup_mask |= getattr(select, "POLLRDHUP", 0)

    poller = select.poll()
    poller.register(descriptor, hangup_mask)
    # Timeout 0: report whatever is pending right now, never block.
    pending_events = poller.poll(0)
    return len(pending_events) == 0


class GMSRPCServer:
    """Unix-socket transport for the GPU Memory Service.

    Listens on ``socket_path``. For each connection it reads a first message
    and either answers a one-shot query (runtime state / event history) or
    performs a lock handshake. After a successful handshake it dispatches
    subsequent requests to the wrapped :class:`GMS` instance.
    """

    def __init__(
        self,
        socket_path: str,
        device: int = 0,
        *,
        allocation_retry_interval: float = 0.5,
        allocation_retry_timeout: Optional[float] = None,
    ):
        """Create the server. No socket is bound until :meth:`serve` runs.

        Args:
            socket_path: Filesystem path of the Unix socket to listen on.
            device: Device identifier forwarded to :class:`GMS`.
            allocation_retry_interval: Forwarded to :class:`GMS`.
            allocation_retry_timeout: Forwarded to :class:`GMS`.
        """
        self.socket_path = socket_path
        self.device = device
        self._gms = GMS(
            device,
            allocation_retry_interval=allocation_retry_interval,
            allocation_retry_timeout=allocation_retry_timeout,
        )
        self._server: Optional[asyncio.Server] = None
        logger.info("GMSRPCServer initialized: device=%d", device)

    def _prepare_socket_path(self) -> None:
        """Make ``socket_path`` available for binding.

        A leftover socket file with nothing listening behind it (for example
        one left by a crashed server) is removed.

        Raises:
            RuntimeError: if a server already accepts connections on the path,
                or if the path exists but is not a Unix socket.
        """
        try:
            mode = os.stat(self.socket_path).st_mode
        except OSError:
            # Nothing usable at the path (mirrors ``os.path.exists`` == False).
            return

        if not stat.S_ISSOCK(mode):
            # Connecting to a regular file fails just like a stale socket,
            # so without this guard a misconfigured path would delete an
            # unrelated file.
            raise RuntimeError(
                f"{self.socket_path} exists and is not a Unix socket"
            )

        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as probe:
            try:
                probe.connect(self.socket_path)
            except OSError:
                # Nobody is listening: the socket file is stale.
                try:
                    os.unlink(self.socket_path)
                except FileNotFoundError:
                    # Removed concurrently; the path is free either way.
                    pass
                return

        raise RuntimeError(f"GMS already running at {self.socket_path}")

    @property
    def state(self):
        """State as reported by the wrapped :class:`GMS`."""
        return self._gms.state

    def is_ready(self) -> bool:
        """Return the readiness reported by the wrapped :class:`GMS`."""
        return self._gms.is_ready()

    async def _handle_connection(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one client connection from handshake to teardown.

        A peer reset is logged at debug level. Invariant violations
        (``InvalidTransition`` / ``AssertionError``) and any other unexpected
        exception are escalated through ``fail``. However the connection
        ends, the GMS is asked to clean it up and the writer is closed.
        """
        conn: Optional[Connection] = None
        session_id = self._gms.next_session_id()
        try:
            conn = await self._do_handshake(reader, writer, session_id)
            if conn is None:
                return
            await self._request_loop(conn)
        except (InvalidTransition, AssertionError) as exc:
            fail("fatal server error", exc_info=exc)
        except ConnectionResetError:
            logger.debug("Connection reset: %s", session_id)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            fail("fatal server error", exc_info=exc)
        finally:
            try:
                await self._gms.cleanup_connection(conn)
            except Exception as exc:
                fail("fatal server error", exc_info=exc)
            # Some exit paths (e.g. a failed handshake recv) never close the
            # stream themselves; close it here so the socket is not held open
            # until garbage collection.
            if not writer.is_closing():
                writer.close()

    async def _do_handshake(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        session_id: str,
    ) -> Optional[Connection]:
        """Read the first message on a new connection and act on it.

        ``GetRuntimeStateRequest`` and ``GetEventHistoryRequest`` are answered
        directly and the writer is closed. A ``HandshakeRequest`` attempts to
        acquire the requested lock. When the lock is granted, a
        :class:`Connection` is registered with the GMS and returned once a
        successful ``HandshakeResponse`` has been sent.

        Returns:
            The established connection, or ``None`` for every other outcome.
        """
        try:
            msg, _, recv_buffer = await recv_message(reader, bytearray())
        except Exception:
            logger.exception("Handshake recv error")
            return None

        if isinstance(msg, GetRuntimeStateRequest):
            try:
                await send_message(writer, self._gms.get_runtime_state())
            except Exception as exc:
                logger.debug("Runtime-state response send failed: %s", exc)
            finally:
                writer.close()
            return None

        if isinstance(msg, GetEventHistoryRequest):
            try:
                await send_message(writer, self._gms.get_event_history())
            except Exception as exc:
                logger.debug("Event-history response send failed: %s", exc)
            finally:
                writer.close()
            return None

        if not isinstance(msg, HandshakeRequest):
            try:
                await send_message(
                    writer, ErrorResponse(error="Expected HandshakeRequest")
                )
            except Exception as exc:
                # Best effort: the connection is dropped regardless.
                logger.debug("Handshake error response send failed: %s", exc)
            writer.close()
            return None

        granted_mode = await self._gms.acquire_lock(
            msg.lock_type,
            msg.timeout_ms,
            session_id,
        )
        if granted_mode is None:
            # Lock not granted: report failure and drop the connection.
            try:
                await send_message(
                    writer,
                    HandshakeResponse(success=False, committed=self._gms.committed),
                )
            except Exception as exc:
                logger.debug("Handshake rejection send failed: %s", exc)
            writer.close()
            return None

        try:
            conn = Connection(
                reader=reader,
                writer=writer,
                mode=granted_mode,
                session_id=session_id,
                recv_buffer=recv_buffer,
            )
            self._gms.on_connect(conn)
        except Exception:
            # Give back the lock acquired above before propagating.
            await self._gms.cancel_connect(session_id, granted_mode)
            raise

        try:
            await send_message(
                writer,
                HandshakeResponse(
                    success=True,
                    committed=self._gms.committed,
                    granted_lock_type=granted_mode,
                ),
            )
        except Exception as exc:
            logger.warning(
                "Handshake failed after acquiring %s for session %s: %s",
                granted_mode.value,
                session_id,
                exc,
            )
            await self._gms.cleanup_connection(conn)
            return None

        return conn

    async def _request_loop(self, conn: Connection) -> None:
        """Receive, dispatch and answer requests on an established connection.

        Returns when the peer resets, a receive or send fails, the handler
        reports a lost connection, or the GMS asks for the connection to
        close. Requests rejected with ``OperationNotAllowed``, ``ValueError``,
        ``TimeoutError`` or ``AllocationNotFoundError`` are answered with an
        ``ErrorResponse`` and the loop continues. Any other exception is
        escalated through ``fail``.
        """
        while True:
            try:
                msg, _, conn.recv_buffer = await recv_message(
                    conn.reader, conn.recv_buffer
                )
            except ConnectionResetError:
                return
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                logger.warning("Recv error on session %s: %s", conn.session_id, exc)
                return

            # NOTE(review): recv_message may apparently yield no message
            # without raising; simply wait for the next one.
            if msg is None:
                continue

            fd = -1
            try:
                response, fd, should_close = await self._gms.handle_request(
                    conn,
                    msg,
                    # Presumably lets the handler check whether the client
                    # is still connected while it works.
                    lambda: _is_connection_alive(conn),
                )
            except ConnectionAbortedError as exc:
                logger.warning(
                    "Connection lost during %s on session %s: %s",
                    type(msg).__name__,
                    conn.session_id,
                    exc,
                )
                return
            except (
                OperationNotAllowed,
                ValueError,
                TimeoutError,
                AllocationNotFoundError,
            ) as exc:
                logger.warning(
                    "Rejected %s from session %s: %s",
                    type(msg).__name__,
                    conn.session_id,
                    exc,
                )
                try:
                    await send_message(conn.writer, ErrorResponse(error=str(exc)))
                except Exception as send_exc:
                    logger.warning(
                        "Failed to send ErrorResponse for %s on session %s: %s",
                        type(msg).__name__,
                        conn.session_id,
                        send_exc,
                    )
                    return
                continue
            except (InvalidTransition, AssertionError) as exc:
                fail("fatal server error", exc_info=exc)
            except Exception as exc:
                fail("fatal server error", exc_info=exc)

            # A non-negative fd returned by the handler is sent along with the
            # response and then closed locally, whether or not the send worked.
            try:
                await send_message(conn.writer, response, fd)
            except Exception as exc:
                logger.warning(
                    "Response send failed for %s on session %s: %s",
                    type(msg).__name__,
                    conn.session_id,
                    exc,
                )
                return
            finally:
                if fd >= 0:
                    os.close(fd)

            if should_close:
                return

    async def serve(self) -> None:
        """Bind ``socket_path`` and serve connections until cancelled.

        Raises:
            RuntimeError: if another server already listens on the path or
                the path exists but is not a Unix socket.
        """
        self._prepare_socket_path()
        self._server = await asyncio.start_unix_server(
            self._handle_connection,
            path=self.socket_path,
        )
        logger.info("Server started: %s", self.socket_path)
        await self._server.serve_forever()