session.py 11.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Server-side connection, FSM, and waiter state."""

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from typing import Optional, Set

from gpu_memory_service.common.types import (
    RO_ALLOWED,
    RW_ALLOWED,
    RW_REQUIRED,
    GrantedLockType,
    RequestedLockType,
    ServerState,
    StateEvent,
)


@dataclass(eq=False)
class Connection:
    """A stream reader/writer pair plus the lock mode and session id it carries.

    ``eq=False`` keeps the default identity-based equality, so two distinct
    ``Connection`` objects never compare equal; hashing uses ``session_id``.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    mode: GrantedLockType
    session_id: str
    # Per-connection byte buffer; starts out empty.
    recv_buffer: bytearray = field(default_factory=bytearray)

    def __hash__(self) -> int:
        """Return the hash of ``session_id`` so connections can live in sets."""
        return hash(self.session_id)

    async def close(self) -> None:
        """Close the writer and wait for the close to complete.

        Errors raised while waiting are deliberately ignored; an error raised
        by ``writer.close()`` itself still propagates.
        """
        stream = self.writer
        stream.close()
        try:
            await stream.wait_closed()
        except Exception:
            # Best-effort shutdown: a failure here is not reported to callers.
            return None


class InvalidTransition(Exception):
    """Raised when an invalid state transition is attempted.

    ``GMSLocalFSM.transition`` raises this when no entry in ``TRANSITIONS``
    matches the current state, the event, and the entry's condition.
    """


class OperationNotAllowed(Exception):
    """Raised when an operation is not allowed in the current state/mode.

    ``GMSLocalFSM.check_operation`` raises this when a message type is not
    permitted for the lock mode granted to the connection.
    """


@dataclass(frozen=True)
class Transition:
    """One immutable row of the ``TRANSITIONS`` table."""

    # States in which this row may fire.
    from_states: frozenset[ServerState]
    # Event that selects this row.
    event: StateEvent
    # Declared resulting state. The FSM in this file derives its state from
    # its own fields after applying the event and does not read this value.
    to_state: Optional[ServerState]
    # Name of a guard evaluated by GMSLocalFSM._check_condition
    # ("has_remaining_readers" / "is_last_reader"); None means unconditional.
    condition: Optional[str] = None


# Transition table. GMSLocalFSM._find_transition scans it in order and uses
# the first row whose source states, event, and condition all match.
# Columns: source states, event, declared target state, condition name.
TRANSITIONS: list[Transition] = [
    Transition(
        from_states=frozenset(sources),
        event=trigger,
        to_state=target,
        condition=guard,
    )
    for sources, trigger, target, guard in (
        (
            (ServerState.EMPTY, ServerState.COMMITTED),
            StateEvent.RW_CONNECT,
            ServerState.RW,
            None,
        ),
        (
            (ServerState.RW,),
            StateEvent.RW_COMMIT,
            ServerState.COMMITTED,
            None,
        ),
        (
            (ServerState.RW,),
            StateEvent.RW_ABORT,
            ServerState.EMPTY,
            None,
        ),
        (
            (ServerState.COMMITTED, ServerState.RO),
            StateEvent.RO_CONNECT,
            ServerState.RO,
            None,
        ),
        # The two RO_DISCONNECT rows are told apart only by their condition.
        (
            (ServerState.RO,),
            StateEvent.RO_DISCONNECT,
            ServerState.RO,
            "has_remaining_readers",
        ),
        (
            (ServerState.RO,),
            StateEvent.RO_DISCONNECT,
            ServerState.COMMITTED,
            "is_last_reader",
        ),
    )
]


class GMSLocalFSM:
    """Explicit connection/lock state machine.

    The current state is not stored; it is derived on demand from three
    fields: the single RW connection, the set of RO connections, and the
    committed flag.
    """

    def __init__(self):
        self._committed = False
        self._ro_conns: Set[Connection] = set()
        self._rw_conn: Optional[Connection] = None

    @property
    def state(self) -> ServerState:
        """Derive the state: RW holder first, then readers, then the commit flag."""
        if self._rw_conn is not None:
            return ServerState.RW
        if self._ro_conns:
            return ServerState.RO
        return ServerState.COMMITTED if self._committed else ServerState.EMPTY

    @property
    def rw_conn(self) -> Optional[Connection]:
        """The connection currently holding RW, or ``None``."""
        return self._rw_conn

    @property
    def ro_conns(self) -> Set[Connection]:
        """The live set of RO connections (the internal set, not a copy)."""
        return self._ro_conns

    @property
    def ro_count(self) -> int:
        """Number of registered RO connections."""
        return len(self._ro_conns)

    @property
    def committed(self) -> bool:
        """Whether the commit flag is currently set."""
        return self._committed

    def _has_remaining_readers(self, conn: Connection) -> bool:
        """True when removing ``conn`` would not empty a one-element reader set."""
        readers = self._ro_conns
        if len(readers) > 1:
            return True
        return conn not in readers

    def _is_last_reader(self, conn: Connection) -> bool:
        """True when ``conn`` is the one and only registered reader."""
        readers = self._ro_conns
        if len(readers) != 1:
            return False
        return conn in readers

    def _check_condition(self, condition: Optional[str], conn: Connection) -> bool:
        """Evaluate a named transition guard for ``conn``.

        ``None`` always passes. Raises ``ValueError`` for an unknown name.
        """
        if condition is None:
            return True
        guards = {
            "has_remaining_readers": self._has_remaining_readers,
            "is_last_reader": self._is_last_reader,
        }
        guard = guards.get(condition)
        if guard is None:
            raise ValueError(f"Unknown condition: {condition}")
        return guard(conn)

    def _find_transition(
        self,
        from_state: ServerState,
        event: StateEvent,
        conn: Connection,
    ) -> Optional[Transition]:
        """Return the first table row matching state, event, and guard, else ``None``."""
        # The guard is evaluated last, and only for rows that already match
        # on state and event.
        candidates = (
            row
            for row in TRANSITIONS
            if from_state in row.from_states
            and row.event == event
            and self._check_condition(row.condition, conn)
        )
        return next(candidates, None)

    def _apply_event(self, event: StateEvent, conn: Connection) -> None:
        """Mutate the underlying fields for ``event``; unknown events are a no-op."""
        if event == StateEvent.RW_CONNECT:
            # A new writer invalidates any previous commit.
            self._rw_conn = conn
            self._committed = False
            return
        if event == StateEvent.RW_COMMIT:
            self._committed = True
            self._rw_conn = None
            return
        if event == StateEvent.RW_ABORT:
            self._rw_conn = None
            return
        if event == StateEvent.RO_CONNECT:
            self._ro_conns.add(conn)
            return
        if event == StateEvent.RO_DISCONNECT:
            self._ro_conns.discard(conn)

    def transition(self, event: StateEvent, conn: Connection) -> ServerState:
        """Apply ``event`` for ``conn`` and return the resulting state.

        Raises ``InvalidTransition`` when the table has no matching row for
        the current state.
        """
        matched = self._find_transition(self.state, event, conn)
        if matched is None:
            raise InvalidTransition(
                f"No transition for {event.name} from state {self.state.name} "
                f"(session={conn.session_id})"
            )
        self._apply_event(event, conn)
        return self.state

    def check_operation(self, msg_type: type, conn: Connection) -> None:
        """Raise ``OperationNotAllowed`` if ``msg_type`` is not permitted for ``conn.mode``."""
        mode = conn.mode
        holds_rw = mode == GrantedLockType.RW
        if holds_rw and msg_type not in RW_ALLOWED:
            raise OperationNotAllowed(
                f"{msg_type.__name__} not allowed for RW session in state {self.state.name}"
            )
        if mode == GrantedLockType.RO and msg_type not in RO_ALLOWED:
            raise OperationNotAllowed(
                f"{msg_type.__name__} not allowed for RO session in state {self.state.name}"
            )
        if msg_type in RW_REQUIRED and not holds_rw:
            raise OperationNotAllowed(
                f"{msg_type.__name__} requires RW session, got {conn.mode.value}"
            )

    def can_acquire_rw(self) -> bool:
        """RW is available only with no writer and no readers registered."""
        return not (self._rw_conn is not None or self._ro_conns)

    def can_acquire_ro(self, waiting_writers: int) -> bool:
        """RO needs a commit, no active writer, and no writers waiting."""
        if not self._committed or self._rw_conn is not None:
            return False
        return waiting_writers == 0


@dataclass(frozen=True)
class SessionSnapshot:
    """Immutable point-in-time view produced by ``GMSSessionManager.snapshot``."""

    # FSM state at the time of the snapshot.
    state: ServerState
    # True when an RW connection is registered with the FSM.
    has_rw_session: bool
    # Number of RO connections registered with the FSM.
    ro_session_count: int
    # Number of RW requests currently counted as waiting by the manager.
    waiting_writers: int
    # Value of the FSM's committed flag.
    committed: bool
    # Computed by snapshot() as: committed and no RW session.
    is_ready: bool


class GMSSessionManager:
    """Owns lock transitions, waiter coordination, and cleanup.

    Lock requests wait on a single ``asyncio.Condition``. A granted RW lock is
    first *reserved* for a session id by ``acquire_lock`` and only becomes the
    FSM's RW connection when ``on_connect`` is called for that session.
    """

    def __init__(self) -> None:
        self._locking = GMSLocalFSM()
        # RW requests currently inside acquire_lock(); RO grants require 0.
        self._waiting_writers = 0
        # Session id granted RW by acquire_lock() but not yet connected.
        self._reserved_rw_session_id: Optional[str] = None
        self._condition = asyncio.Condition()
        self._next_session_id = 0

    @property
    def state(self) -> ServerState:
        """Current FSM state."""
        return self._locking.state

    def next_session_id(self) -> str:
        """Return a fresh id of the form ``session_<n>``, counting from 1."""
        self._next_session_id += 1
        return f"session_{self._next_session_id}"

    def snapshot(self) -> SessionSnapshot:
        """Capture the current lock state as an immutable ``SessionSnapshot``."""
        has_rw_session = self._locking.rw_conn is not None
        return SessionSnapshot(
            state=self._locking.state,
            has_rw_session=has_rw_session,
            ro_session_count=self._locking.ro_count,
            waiting_writers=self._waiting_writers,
            committed=self._locking.committed,
            is_ready=self._locking.committed and not has_rw_session,
        )

    def _can_grant_rw(self) -> bool:
        """RW is grantable when nothing is reserved and the FSM allows RW."""
        return self._reserved_rw_session_id is None and self._locking.can_acquire_rw()

    def _can_grant_ro(self) -> bool:
        """RO is grantable when nothing is reserved and the FSM allows RO."""
        return self._reserved_rw_session_id is None and self._locking.can_acquire_ro(
            self._waiting_writers
        )

    def _can_grant_rw_or_ro(self) -> bool:
        """Predicate for RW_OR_RO requests: RO, or RW while nothing is committed."""
        if self._can_grant_ro():
            return True
        return self._can_grant_rw() and not self._locking.committed

    async def _wait_until(self, predicate, timeout: Optional[float]) -> bool:
        """Wait until ``predicate()`` is true; return ``False`` on timeout.

        The caller must already hold ``self._condition``. ``predicate`` is a
        zero-argument callable; ``timeout`` is in seconds, ``None`` meaning
        wait indefinitely.
        """
        # Check once before waiting: asyncio.wait_for() with a timeout <= 0
        # cancels a waiter that has not started running, so a zero timeout
        # would otherwise report a timeout even when the lock is free.
        if predicate():
            return True
        try:
            await asyncio.wait_for(
                self._condition.wait_for(predicate),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            return False
        return True

    async def acquire_lock(
        self,
        mode: RequestedLockType,
        timeout_ms: Optional[int],
        session_id: str,
    ) -> Optional[GrantedLockType]:
        """Wait for a lock of the requested kind.

        Args:
            mode: ``RW``, ``RO``, or any other value, which is treated as
                "RW or RO".
            timeout_ms: Maximum wait in milliseconds; ``None`` waits
                indefinitely and ``0`` grants only if immediately available.
            session_id: Session the RW reservation is recorded for.

        Returns:
            The granted lock type, or ``None`` if the wait timed out. An RW
            grant reserves the writer slot for ``session_id`` until
            ``on_connect`` or ``cancel_connect``. An RO grant records nothing
            here; the reader is registered by ``on_connect``.
        """
        timeout = timeout_ms / 1000 if timeout_ms is not None else None

        if mode == RequestedLockType.RW:
            # Only undo the waiting-writer count if it was actually taken;
            # cancellation while acquiring the condition must not drive the
            # counter negative (which would block every RO grant).
            registered = False
            try:
                async with self._condition:
                    self._waiting_writers += 1
                    registered = True
                    if not await self._wait_until(self._can_grant_rw, timeout):
                        return None
                    self._reserved_rw_session_id = session_id
                    return GrantedLockType.RW
            finally:
                if registered:
                    async with self._condition:
                        self._waiting_writers -= 1
                        # Readers may have been held back by this writer.
                        self._condition.notify_all()

        if mode == RequestedLockType.RO:
            async with self._condition:
                if not await self._wait_until(self._can_grant_ro, timeout):
                    return None
            return GrantedLockType.RO

        # RW-or-RO: take RW only while nothing has been committed, else RO.
        async with self._condition:
            if not await self._wait_until(self._can_grant_rw_or_ro, timeout):
                return None
            if self._can_grant_rw() and not self._locking.committed:
                self._reserved_rw_session_id = session_id
                return GrantedLockType.RW
        return GrantedLockType.RO

    async def cancel_connect(
        self,
        session_id: str,
        mode: Optional[GrantedLockType],
    ) -> None:
        """Release an RW reservation held by ``session_id``, if there is one.

        Does nothing for non-RW modes or when another session holds the
        reservation. Waiters are notified when the reservation is cleared.
        """
        if mode != GrantedLockType.RW:
            return
        async with self._condition:
            if self._reserved_rw_session_id == session_id:
                self._reserved_rw_session_id = None
                self._condition.notify_all()

    def on_connect(self, conn: Connection) -> None:
        """Register ``conn`` with the FSM using its granted mode.

        Raises:
            AssertionError: an RW connection whose session id is not the
                currently reserved one.
            InvalidTransition: the FSM rejects the connect event.
        """
        if conn.mode == GrantedLockType.RW:
            if self._reserved_rw_session_id != conn.session_id:
                raise AssertionError(
                    f"RW session {conn.session_id} was not reserved before connect"
                )
            # The reservation is consumed by the connect.
            self._reserved_rw_session_id = None
        event = (
            StateEvent.RW_CONNECT
            if conn.mode == GrantedLockType.RW
            else StateEvent.RO_CONNECT
        )
        self._locking.transition(event, conn)

    def on_commit(self, conn: Connection) -> None:
        """Apply ``RW_COMMIT`` for ``conn``; raises ``InvalidTransition`` if rejected."""
        self._locking.transition(StateEvent.RW_COMMIT, conn)

    def check_operation(self, msg_type: type, conn: Connection) -> None:
        """Delegate to the FSM; raises ``OperationNotAllowed`` if not permitted."""
        self._locking.check_operation(msg_type, conn)

    def begin_cleanup(self, conn: Optional[Connection]) -> StateEvent | None:
        """Synchronously remove ``conn`` from the FSM.

        Returns the event that was applied: ``RW_ABORT`` for an RW connection
        that still holds the uncommitted RW lock, ``RO_DISCONNECT`` for a
        registered reader, or ``None`` when nothing had to change.
        """
        if conn is None:
            return None

        event = None
        if conn.mode == GrantedLockType.RW:
            if self._locking.rw_conn is conn and not self._locking.committed:
                self._locking.transition(StateEvent.RW_ABORT, conn)
                event = StateEvent.RW_ABORT
        elif conn in self._locking.ro_conns:
            self._locking.transition(StateEvent.RO_DISCONNECT, conn)
            event = StateEvent.RO_DISCONNECT
        return event

    async def finish_cleanup(self, conn: Optional[Connection]) -> None:
        """Close ``conn`` (if any) and wake every lock waiter.

        Waiters are notified even when closing fails or is cancelled, so
        requests blocked in ``acquire_lock`` re-check their predicate; the
        close error itself still propagates.
        """
        try:
            if conn is not None:
                await conn.close()
        finally:
            async with self._condition:
                self._condition.notify_all()