test_gms_client_transport.py 3.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol import wire
from gpu_memory_service.common.protocol.messages import (
    CommitResponse,
    ErrorResponse,
    HandshakeResponse,
)

pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.gpu_0,
]


class _DummySocket:
    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def test_transport_failure_closes_socket_and_marks_disconnected(monkeypatch):
    """A receive-side transport error must close and drop the socket.

    ``send_message_sync`` is stubbed to succeed and ``recv_message_sync`` to
    raise ``BrokenPipeError``. The transport is expected to surface a
    ``ConnectionError``, close the socket it held, and report itself as
    disconnected.
    """
    transport = _GMSRPCTransport("/tmp/gms-test.sock")
    fake_sock = _DummySocket()
    transport._socket = fake_sock

    def _raise_broken_pipe(sock, buffer):
        raise BrokenPipeError("boom")

    monkeypatch.setattr(
        "gpu_memory_service.client.rpc.send_message_sync",
        lambda sock, request: None,
    )
    monkeypatch.setattr(
        "gpu_memory_service.client.rpc.recv_message_sync",
        _raise_broken_pipe,
    )

    with pytest.raises(ConnectionError, match="failed: boom"):
        transport.request(CommitResponse(success=True), HandshakeResponse)

    # The test name promises the socket is *closed*, not merely forgotten;
    # without this check a leaked OS socket would go unnoticed.
    assert fake_sock.closed
    assert not transport.is_connected
    assert transport._socket is None


def test_request_with_fd_closes_fd_on_unexpected_response_type(monkeypatch):
    """An FD delivered alongside the wrong response type must be closed."""
    transport = _GMSRPCTransport("/tmp/gms-test.sock")
    released: list[int] = []

    def fake_send_recv(request, error_prefix=None):
        # Reply with a CommitResponse carrying FD 37, although the caller
        # below asks for a HandshakeResponse.
        return CommitResponse(success=True), 37

    monkeypatch.setattr(transport, "_send_recv", fake_send_recv)
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", released.append)

    with pytest.raises(RuntimeError, match="unexpected response type"):
        transport.request_with_fd(CommitResponse(success=True), HandshakeResponse)

    assert released == [37]


def test_request_closes_fd_on_error_response(monkeypatch):
    """An FD that arrives with an ErrorResponse must be closed before raising."""
    transport = _GMSRPCTransport("/tmp/gms-test.sock")
    transport._socket = _DummySocket()
    released: list[int] = []

    def fake_send(sock, request):
        return None

    def fake_recv(sock, buffer):
        # The server reports an error but still hands back FD 41.
        return ErrorResponse(error="boom"), 41, bytearray()

    monkeypatch.setattr("gpu_memory_service.client.rpc.send_message_sync", fake_send)
    monkeypatch.setattr("gpu_memory_service.client.rpc.recv_message_sync", fake_recv)
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", released.append)

    with pytest.raises(RuntimeError, match="error: boom"):
        transport.request(CommitResponse(success=True), HandshakeResponse)

    assert released == [41]


def test_request_closes_fd_on_unexpected_success_fd(monkeypatch):
    """``request`` must close an FD it never asked for, even on a typed reply."""
    transport = _GMSRPCTransport("/tmp/gms-test.sock")
    released: list[int] = []

    def fake_request_with_fd(request, response_type):
        # A reply of the requested type that nonetheless carries FD 43.
        return CommitResponse(success=True), 43

    monkeypatch.setattr(transport, "request_with_fd", fake_request_with_fd)
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", released.append)

    with pytest.raises(RuntimeError, match="unexpected FD"):
        transport.request(CommitResponse(success=True), CommitResponse)

    assert released == [43]


def test_recv_message_sync_closes_fd_on_decode_failure(monkeypatch):
    """A received FD must be closed when its accompanying frame fails to decode."""
    released: list[int] = []

    def fake_recv_fds(sock, size, maxfds):
        # Mirrors socket.recv_fds' (data, fds, msg_flags, address) result.
        # The payload presumably is a 4-byte length prefix (=1) followed by a
        # 1-byte body — TODO confirm against wire's framing.
        return b"\x00\x00\x00\x01x", [53], 0, None

    def failing_decode(payload):
        raise ValueError("bad frame")

    monkeypatch.setattr(wire.socket, "recv_fds", fake_recv_fds)
    monkeypatch.setattr(wire, "decode_message", failing_decode)
    monkeypatch.setattr(
        "gpu_memory_service.common.protocol.wire.os.close",
        released.append,
    )

    with pytest.raises(ValueError, match="bad frame"):
        wire.recv_message_sync(object(), bytearray())

    assert released == [53]