".github/vscode:/vscode.git/clone" did not exist on "bb9566b70929dcf9f0787438dcd7391c8a884397"
wire.py 5.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Wire protocol for length-prefixed messages with optional FD passing."""

import asyncio
import os
import socket
import struct
from typing import Optional, Tuple

from .messages import Message, decode_message, encode_message

HEADER_SIZE = 4  # 4-byte big-endian length prefix


def _frame_message(msg: Message) -> bytes:
    """Serialize ``msg`` and prefix it with its 4-byte big-endian payload length."""
    payload = encode_message(msg)
    header = struct.pack("!I", len(payload))
    return header + payload


def _try_extract_message(
    recv_buffer: bytearray,
) -> Tuple[Optional[Message], bytearray, int]:
    """Split the first complete frame off the front of ``recv_buffer``.

    Returns ``(message, remaining_buffer, bytes_needed)``:

    * complete frame: the decoded message, a new bytearray holding the bytes
      that follow the frame, and ``0``;
    * incomplete: ``None``, ``recv_buffer`` itself (not a copy), and how many
      more bytes are required to finish the length header, or — once the
      header is present — to finish the whole frame.
    """
    buffered = len(recv_buffer)
    if buffered >= HEADER_SIZE:
        (payload_len,) = struct.unpack_from("!I", recv_buffer)
        frame_end = HEADER_SIZE + payload_len
        if buffered >= frame_end:
            payload = bytes(recv_buffer[HEADER_SIZE:frame_end])
            leftover = bytearray(recv_buffer[frame_end:])
            return decode_message(payload), leftover, 0
        shortfall = frame_end - buffered
    else:
        shortfall = HEADER_SIZE - buffered
    return None, recv_buffer, shortfall


# ==================== Async (for server) ====================


async def send_message(writer, msg: Message, fd: int = -1) -> None:
    """Send one length-prefixed message, optionally passing ``fd`` via SCM_RIGHTS.

    Without an FD the frame is written through ``writer`` (``write`` +
    ``drain``). With an FD, the frame is sent directly on the transport's
    socket with ``socket.send_fds``, run in the default executor, because the
    stream writer cannot carry ancillary data.

    Args:
        writer: asyncio stream writer. When ``fd >= 0``,
            ``writer.get_extra_info("socket")`` must return the underlying
            socket.
        msg: message to encode and frame.
        fd: descriptor to pass alongside the frame, or -1 for none. This
            function never closes ``fd``; ownership stays with the caller.

    Raises:
        RuntimeError: if ``fd >= 0`` and the transport exposes no socket.
    """
    frame = _frame_message(msg)

    if fd < 0:
        writer.write(frame)
        await writer.drain()
        return

    transport_sock = writer.get_extra_info("socket")
    if transport_sock is None:
        raise RuntimeError("Cannot get socket from transport for FD passing")

    def do_send_fd():
        # Work on a private duplicate descriptor so the transport's own socket
        # object is never closed or detached by us.
        dup_fd = os.dup(transport_sock.fileno())
        try:
            sock = socket.socket(fileno=dup_fd)
        except BaseException:
            os.close(dup_fd)
            raise
        # Closing ``sock`` closes only the duplicate, so no descriptor is
        # leaked on either the success or the error path; the transport's
        # descriptor stays open.
        with sock:
            # NOTE(review): O_NONBLOCK is a property of the open file
            # description shared by dup'd descriptors, so this also switches
            # the transport's socket to blocking mode and it is never restored.
            # Kept as-is because other code may rely on it — confirm.
            sock.setblocking(True)
            sent = socket.send_fds(sock, [frame], [fd])
            # sendmsg may accept only part of the frame; the FD travels with
            # the first chunk, so finish the rest with a plain sendall.
            if sent < len(frame):
                sock.sendall(memoryview(frame)[sent:])

    # NOTE(review): bytes still queued in ``writer``'s transport buffer are not
    # flushed before this direct send, so they could be reordered after this
    # frame — confirm callers never interleave FD and non-FD sends.
    await asyncio.get_running_loop().run_in_executor(None, do_send_fd)


async def recv_message(
    reader, recv_buffer: Optional[bytearray] = None, raw_sock=None
) -> Tuple[Optional[Message], int, bytearray]:
    """Receive one length-prefixed message, plus an optional passed FD.

    If ``recv_buffer`` already holds a complete frame, it is decoded without
    any I/O. Otherwise one read is issued — ``socket.recv_fds`` on
    ``raw_sock`` (run in the default executor) when given, else
    ``reader.read`` — followed by further reads until the frame is complete.

    Args:
        reader: object with an awaitable ``read(n)``; used only when
            ``raw_sock`` is None.
        recv_buffer: bytes left over from a previous call; may be extended in
            place.
        raw_sock: optional socket read via calls run in the default executor
            (presumably in blocking mode — confirm with the caller). Only this
            path can receive an FD.

    Returns:
        ``(message, fd, remaining_buffer)``. ``fd`` is -1 if none was received;
        otherwise the caller owns it and must close it. At most one FD is kept
        per call; any extras are closed.

    Raises:
        ConnectionResetError: if the peer closes before a full frame arrives.
    """
    if recv_buffer is None:
        recv_buffer = bytearray()

    # Fast path: a complete frame is already buffered, no I/O needed.
    msg, remaining, _ = _try_extract_message(recv_buffer)
    if msg is not None:
        return msg, -1, remaining

    loop = asyncio.get_running_loop()
    fd = -1

    # Receive more data (with a possible FD on the raw-socket path).
    if raw_sock is not None:
        # NOTE(review): if this task is cancelled while recv_fds runs in the
        # executor, the worker thread still completes and any FD it received
        # is discarded unclosed.
        raw_msg, fds, _flags, _addr = await loop.run_in_executor(
            None, lambda: socket.recv_fds(raw_sock, 65536, 1)
        )
        # Keep at most one FD; close any extras so they do not leak.
        for extra_fd in fds[1:]:
            os.close(extra_fd)
        if not raw_msg:
            if fds:
                os.close(fds[0])
            raise ConnectionResetError("Connection closed")
        recv_buffer.extend(raw_msg)
        fd = fds[0] if fds else -1
    else:
        chunk = await reader.read(65536)
        if not chunk:
            raise ConnectionResetError("Connection closed")
        recv_buffer.extend(chunk)

    # Try to extract the message, reading more if needed.
    try:
        msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
        while msg is None and bytes_needed > 0:
            # Cap each read: raw_sock.recv(n) preallocates n bytes, and the
            # length header comes off the wire (up to 4 GiB).
            want = min(bytes_needed, 65536)
            if raw_sock is not None:
                # Continue reading from the raw socket so data is never split
                # between it and ``reader``'s internal buffer.
                chunk = await loop.run_in_executor(None, raw_sock.recv, want)
            else:
                chunk = await reader.read(want)
            if not chunk:
                raise ConnectionResetError("Connection closed")
            remaining.extend(chunk)
            msg, remaining, bytes_needed = _try_extract_message(remaining)
        return msg, fd, remaining
    except BaseException:
        # BaseException, not Exception: asyncio.CancelledError derives from
        # BaseException (3.8+), and the received FD must be released on
        # cancellation too.
        if fd >= 0:
            os.close(fd)
        raise
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163


# ==================== Sync (for client) ====================


def send_message_sync(sock, msg: Message, fd: int = -1) -> None:
    """Send one length-prefixed message, optionally passing ``fd`` via SCM_RIGHTS.

    Args:
        sock: connected socket (AF_UNIX when ``fd`` is passed).
        msg: message to encode and frame.
        fd: descriptor to pass alongside the frame, or -1 for none. This
            function never closes ``fd``; ownership stays with the caller.
    """
    frame = _frame_message(msg)
    if fd < 0:
        sock.sendall(frame)
        return
    # sendmsg may accept only part of the frame (e.g. on a socket with a
    # timeout). The FD travels with the first chunk, so finish with sendall.
    sent = socket.send_fds(sock, [frame], [fd])
    if sent < len(frame):
        sock.sendall(memoryview(frame)[sent:])


def recv_message_sync(
    sock, recv_buffer: Optional[bytearray] = None
) -> Tuple[Optional[Message], int, bytearray]:
    """Receive one length-prefixed message, plus an optional passed FD.

    If ``recv_buffer`` already holds a complete frame, it is decoded without
    any I/O. Otherwise one ``socket.recv_fds`` call is made, followed by plain
    ``recv`` calls until the frame is complete.

    Args:
        sock: connected socket to read from.
        recv_buffer: bytes left over from a previous call; may be extended in
            place.

    Returns:
        ``(message, fd, remaining_buffer)``. ``fd`` is -1 if none was received;
        otherwise the caller owns it and must close it. At most one FD is kept
        per call; any extras are closed.

    Raises:
        ConnectionResetError: if the peer closes before a full frame arrives.
    """
    if recv_buffer is None:
        recv_buffer = bytearray()

    # Fast path: a complete frame is already buffered, no I/O needed.
    msg, remaining, _ = _try_extract_message(recv_buffer)
    if msg is not None:
        return msg, -1, remaining

    # Receive more data (with a possible FD).
    raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
    # Keep at most one FD; close any extras so they do not leak.
    for extra_fd in fds[1:]:
        os.close(extra_fd)
    if not raw_msg:
        if fds:
            os.close(fds[0])
        raise ConnectionResetError("Connection closed")
    recv_buffer.extend(raw_msg)
    fd = fds[0] if fds else -1

    # Try to extract the message, reading more if needed.
    try:
        msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
        while msg is None and bytes_needed > 0:
            # Cap each read: sock.recv(n) preallocates n bytes, and the
            # length header comes off the wire (up to 4 GiB).
            chunk = sock.recv(min(bytes_needed, 65536))
            if not chunk:
                raise ConnectionResetError("Connection closed")
            remaining.extend(chunk)
            msg, remaining, bytes_needed = _try_extract_message(remaining)
        return msg, fd, remaining
    except BaseException:
        # BaseException so the FD is released even on KeyboardInterrupt
        # during a blocking recv.
        if fd >= 0:
            os.close(fd)
        raise