rpc.py 4.52 KB
Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
"""Internal GPU Memory Service transport.

This module only owns Unix socket transport and typed request/response exchange.
Session semantics live in `gpu_memory_service.client.session`.
"""

10
11
from __future__ import annotations

12
import logging
13
import os
14
import socket
15
from typing import Optional, Tuple, Type, TypeVar
16
17
18
19
20
21
22

from gpu_memory_service.common.protocol.messages import (
    ErrorResponse,
    HandshakeRequest,
    HandshakeResponse,
)
from gpu_memory_service.common.protocol.wire import recv_message_sync, send_message_sync
23
from gpu_memory_service.common.types import RequestedLockType
24
25
26
27
28
29

T = TypeVar("T")

logger = logging.getLogger(__name__)


30
31
class _GMSRPCTransport:
    """Raw GMS Unix socket transport."""
32

33
    def __init__(self, socket_path: str):
34
35
36
37
        self.socket_path = socket_path
        self._socket: Optional[socket.socket] = None
        self._recv_buffer = bytearray()

38
39
40
    @property
    def is_connected(self) -> bool:
        return self._socket is not None
41

42
    def connect(self) -> None:
43
44
45
46
        self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            self._socket.connect(self.socket_path)
        except FileNotFoundError:
47
48
            self._socket.close()
            self._socket = None
49
50
51
52
            raise ConnectionError(
                f"GMS server not running at {self.socket_path}"
            ) from None
        except Exception as exc:
53
54
            self._socket.close()
            self._socket = None
55
            raise ConnectionError(f"Failed to connect to GMS: {exc}") from exc
56

57
58
59
60
61
62
63
64
65
    def handshake(
        self,
        lock_type: RequestedLockType,
        timeout_ms: Optional[int],
    ) -> HandshakeResponse:
        response, _ = self.request_with_fd(
            HandshakeRequest(lock_type=lock_type, timeout_ms=timeout_ms),
            HandshakeResponse,
            error_prefix="GMS handshake",
66
67
68
        )
        return response

69
70
71
72
73
74
75
76
    def request(self, request, response_type: Type[T]) -> T:
        response, fd = self.request_with_fd(request, response_type)
        if fd >= 0:
            os.close(fd)
            raise RuntimeError(
                f"GMS request {type(request).__name__} returned an unexpected FD"
            )
        return response
77

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    def request_with_fd(
        self,
        request,
        response_type: Type[T],
        *,
        error_prefix: Optional[str] = None,
    ) -> Tuple[T, int]:
        response, fd = self._send_recv(request, error_prefix=error_prefix)
        if not isinstance(response, response_type):
            prefix = error_prefix or f"GMS request {type(request).__name__}"
            if fd >= 0:
                os.close(fd)
            raise RuntimeError(
                f"{prefix} returned unexpected response type: {type(response)}"
            )
        return response, fd
94

95
96
97
98
99
    def _send_recv(
        self, request, *, error_prefix: Optional[str] = None
    ) -> Tuple[object, int]:
        if self._socket is None:
            raise RuntimeError("Attempted GMS request on disconnected transport")
100

101
        prefix = error_prefix or f"GMS request {type(request).__name__}"
102
        try:
103
104
105
            send_message_sync(self._socket, request)
            response, fd, self._recv_buffer = recv_message_sync(
                self._socket, self._recv_buffer
106
            )
107
        except Exception as exc:
108
109
110
111
112
            try:
                self._socket.close()
            except Exception:
                pass
            self._socket = None
113
            raise ConnectionError(f"{prefix} failed: {exc}") from exc
114

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
        if isinstance(response, ErrorResponse):
            if fd >= 0:
                os.close(fd)
            raise RuntimeError(f"{prefix} error: {response.error}")
        return response, fd

    def close(self) -> None:
        if self._socket is None:
            return
        try:
            self._socket.close()
        except Exception as exc:
            raise ConnectionError(
                f"Failed to close GMS transport socket: {exc}"
            ) from exc
        finally:
            self._socket = None

    def __enter__(self) -> "_GMSRPCTransport":
134
135
136
137
138
139
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def __del__(self):
140
141
142
143
144
145
146
        if self._socket is not None:
            try:
                self._socket.close()
            except Exception:
                pass
            self._socket = None
            logger.warning("_GMSRPCTransport not closed properly")