"vllm/vscode:/vscode.git/clone" did not exist on "0a5738672158c07d5d66ac9f8c9e8876f2939bb9"
test_gms_client_transport.py 3.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest
from gpu_memory_service.client.rpc import _GMSRPCTransport
from gpu_memory_service.common.protocol import wire
from gpu_memory_service.common.protocol.messages import (
    CommitResponse,
    ErrorResponse,
    HandshakeResponse,
)

# Markers applied to every test in this module (pytest reads the module-level
# ``pytestmark`` list). Only pytest mark objects belong here.
# NOTE(review): ``gpu_0`` presumably means "needs zero GPUs"; the tests below
# only use stubs and never touch a device — confirm against the CI marker docs.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.none,
    pytest.mark.gpu_0,
]


class _DummySocket:
    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def test_transport_failure_closes_socket_and_marks_disconnected(monkeypatch):
    """An I/O error mid-request must close the socket and drop the connection.

    ``recv_message_sync`` is stubbed to raise ``BrokenPipeError``. The transport
    is expected to surface it as ``ConnectionError``, close the socket it was
    holding, and forget it so that ``is_connected`` reports False.
    """
    transport = _GMSRPCTransport("/tmp/gms-test.sock")
    # Keep our own reference: ``transport._socket`` is cleared on failure, so
    # this handle is the only way to verify that close() was actually called.
    dummy_socket = _DummySocket()
    transport._socket = dummy_socket

    def _recv_broken_pipe(sock, buffer):
        # Simulate the peer going away while we wait for the reply.
        raise BrokenPipeError("boom")

    monkeypatch.setattr(
        "gpu_memory_service.client.rpc.send_message_sync",
        lambda sock, request: None,
    )
    monkeypatch.setattr(
        "gpu_memory_service.client.rpc.recv_message_sync",
        _recv_broken_pipe,
    )

    with pytest.raises(ConnectionError, match="failed: boom"):
        transport.request(CommitResponse(success=True), HandshakeResponse)

    # The test name promises the socket is closed, not merely dropped.
    assert dummy_socket.closed
    assert not transport.is_connected
    assert transport._socket is None


def test_request_with_fd_closes_fd_on_unexpected_response_type(monkeypatch):
    """``request_with_fd`` must not leak a received FD when the reply type is wrong.

    ``_send_recv`` is stubbed to return a ``CommitResponse`` carrying FD 37,
    while the caller asked for a ``HandshakeResponse``.
    """
    rpc_client = _GMSRPCTransport("/tmp/gms-test.sock")
    released: list[int] = []

    def fake_send_recv(request, error_prefix=None):
        return CommitResponse(success=True), 37

    monkeypatch.setattr(rpc_client, "_send_recv", fake_send_recv)
    # Record os.close calls instead of closing a (non-existent) descriptor.
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", released.append)

    with pytest.raises(RuntimeError, match="unexpected response type"):
        rpc_client.request_with_fd(CommitResponse(success=True), HandshakeResponse)

    assert released == [37]


def test_request_closes_fd_on_error_response(monkeypatch):
    """An ``ErrorResponse`` that arrives with an FD attached must still close it."""
    rpc_client = _GMSRPCTransport("/tmp/gms-test.sock")
    rpc_client._socket = _DummySocket()
    released: list[int] = []

    def fake_send(sock, request):
        return None

    def fake_recv(sock, buffer):
        # Element 1 is the FD the test expects to be closed; element 2 is
        # presumably the leftover receive buffer — confirm against rpc.py.
        return ErrorResponse(error="boom"), 41, bytearray()

    monkeypatch.setattr("gpu_memory_service.client.rpc.send_message_sync", fake_send)
    monkeypatch.setattr("gpu_memory_service.client.rpc.recv_message_sync", fake_recv)
    # Record os.close calls instead of closing a (non-existent) descriptor.
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", released.append)

    with pytest.raises(RuntimeError, match="error: boom"):
        rpc_client.request(CommitResponse(success=True), HandshakeResponse)

    assert released == [41]


def test_request_closes_fd_on_unexpected_success_fd(monkeypatch):
    """``request()`` must close an FD it did not ask for, even when the reply is OK.

    ``request_with_fd`` is stubbed to return the requested ``CommitResponse``
    type but with FD 43 attached.
    """
    rpc_client = _GMSRPCTransport("/tmp/gms-test.sock")
    released: list[int] = []

    def fake_request_with_fd(request, response_type):
        return CommitResponse(success=True), 43

    monkeypatch.setattr(rpc_client, "request_with_fd", fake_request_with_fd)
    # Record os.close calls instead of closing a (non-existent) descriptor.
    monkeypatch.setattr("gpu_memory_service.client.rpc.os.close", released.append)

    with pytest.raises(RuntimeError, match="unexpected FD"):
        rpc_client.request(CommitResponse(success=True), CommitResponse)

    assert released == [43]


def test_recv_message_sync_closes_fd_on_decode_failure(monkeypatch):
    """``recv_message_sync`` must close any received FD if frame decoding fails."""
    released: list[int] = []

    def fake_recv_fds(sock, size, maxfds):
        # Same shape as socket.recv_fds: (data, fds, msg_flags, address).
        # The first 4 bytes look like a length prefix of 1 (presumably
        # big-endian — confirm in wire.py), followed by a 1-byte payload.
        return b"\x00\x00\x00\x01x", [53], 0, None

    def failing_decode(payload):
        raise ValueError("bad frame")

    monkeypatch.setattr(wire.socket, "recv_fds", fake_recv_fds)
    monkeypatch.setattr(wire, "decode_message", failing_decode)
    # Record os.close calls instead of closing a (non-existent) descriptor.
    monkeypatch.setattr(
        "gpu_memory_service.common.protocol.wire.os.close",
        released.append,
    )

    with pytest.raises(ValueError, match="bad frame"):
        wire.recv_message_sync(object(), bytearray())

    assert released == [53]