"tests/fault_tolerance/vscode:/vscode.git/clone" did not exist on "aedfc0a37b23365b3f75a029fc21ada04f6321f6"
wire.py 5.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Wire protocol for length-prefixed messages with optional FD passing."""

import asyncio
import os
import socket
import struct
from typing import Optional, Tuple

from .messages import Message, decode_message, encode_message

# Size in bytes of the frame header: one unsigned 32-bit big-endian length.
HEADER_SIZE = struct.calcsize("!I")


def _frame_message(msg: Message) -> bytes:
    """Serialize *msg* and prepend its 4-byte big-endian length header."""
    payload = encode_message(msg)
    header = struct.pack("!I", len(payload))
    return header + payload


def _try_extract_message(
    recv_buffer: bytearray,
) -> Tuple[Optional[Message], bytearray, int]:
    """Pull one complete framed message off the front of ``recv_buffer``.

    Returns a ``(message, remaining_buffer, bytes_needed)`` triple.

    While the frame is still incomplete, ``message`` is ``None``, the buffer
    is handed back as-is (the same object) and ``bytes_needed`` is how many
    more bytes are required before another attempt can make progress.

    Once a whole frame is present, ``message`` is the decoded payload,
    ``remaining_buffer`` is a new bytearray holding whatever followed the
    frame, and ``bytes_needed`` is 0.
    """
    available = len(recv_buffer)

    # Not even a full length header yet.
    if available < HEADER_SIZE:
        return None, recv_buffer, HEADER_SIZE - available

    (payload_len,) = struct.unpack_from("!I", recv_buffer)
    frame_end = HEADER_SIZE + payload_len

    # Header is known but the payload has not fully arrived.
    shortfall = frame_end - available
    if shortfall > 0:
        return None, recv_buffer, shortfall

    payload = bytes(recv_buffer[HEADER_SIZE:frame_end])
    leftover = bytearray(recv_buffer[frame_end:])
    return decode_message(payload), leftover, 0


# ==================== Async (for server) ====================


async def send_message(writer, msg: Message, fd: int = -1) -> None:
    """Send a length-prefixed message with optional FD via SCM_RIGHTS.

    Args:
        writer: Stream writer for the connection. It must expose the
            underlying socket through ``get_extra_info("socket")`` when an
            FD is being passed.
        msg: Message to encode and frame.
        fd: File descriptor to pass alongside the frame. A negative value
            (the default) sends the frame only, through ``writer``.

    Raises:
        RuntimeError: If an FD is requested but the writer exposes no socket.
    """
    frame = _frame_message(msg)

    if fd < 0:
        writer.write(frame)
        await writer.drain()
        return

    transport_sock = writer.get_extra_info("socket")
    if transport_sock is None:
        raise RuntimeError("Cannot get socket from transport for FD passing")

    def do_send_fd() -> None:
        # Work on a duplicate descriptor so the temporary socket object can
        # be closed without closing the transport's own descriptor.
        dup_fd = os.dup(transport_sock.fileno())
        try:
            sock = socket.socket(fileno=dup_fd)
        except Exception:
            # No socket object owns dup_fd yet, so release it by hand.
            os.close(dup_fd)
            raise

        # Leaving the ``with`` block closes dup_fd on success and on failure
        # alike; detaching instead would leak one descriptor per send.
        with sock:
            # NOTE(review): blocking mode lives on the open file description
            # shared with the transport's descriptor, so this presumably
            # leaves the transport socket blocking as well -- confirm that
            # callers expect this.
            sock.setblocking(True)
            sent = socket.send_fds(sock, [frame], [fd])
            # send_fds reports how many bytes went out; the descriptor
            # travels with those bytes, so only plain data can remain.
            if sent < len(frame):
                sock.sendall(frame[sent:])

    # The send blocks, so keep it off the event loop thread.
    await asyncio.get_running_loop().run_in_executor(None, do_send_fd)


async def recv_message(
    reader, recv_buffer: Optional[bytearray] = None, raw_sock=None
) -> Tuple[Optional[Message], int, bytearray]:
    """Receive a length-prefixed message with optional FD.

    Args:
        reader: Stream reader used when ``raw_sock`` is None.
        recv_buffer: Bytes left over from a previous call, if any. The
            buffer is extended in place while reading.
        raw_sock: Socket to read from directly. When given, the first read
            uses ``socket.recv_fds`` so a passed descriptor can be picked up,
            and every read runs in the default executor.

    Returns:
        ``(message, fd, remaining_buffer)``. ``fd`` is -1 if none was sent,
        and always -1 when the message was already complete in
        ``recv_buffer``. A returned descriptor is not closed here.

    Raises:
        ConnectionResetError: If the peer closes the connection before a
            complete message has been read.
    """
    if recv_buffer is None:
        recv_buffer = bytearray()

    # Check if complete message already in buffer
    msg, remaining, _ = _try_extract_message(recv_buffer)
    if msg is not None:
        return msg, -1, remaining

    loop = asyncio.get_running_loop()
    fd = -1

    try:
        # Receive more data; only this first read can pick up a descriptor.
        if raw_sock is not None:
            raw_msg, fds, _flags, _addr = await loop.run_in_executor(
                None, lambda: socket.recv_fds(raw_sock, 65536, 1)
            )
            if fds:
                fd = fds[0]
            chunk = raw_msg
        else:
            chunk = await reader.read(65536)
        if not chunk:
            raise ConnectionResetError("Connection closed")
        recv_buffer.extend(chunk)

        # Try to extract message, read more if needed
        msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
        while msg is None and bytes_needed > 0:
            if raw_sock is not None:
                # Continue reading from raw socket to avoid buffer inconsistency
                chunk = await loop.run_in_executor(
                    None, lambda n=bytes_needed: raw_sock.recv(n)
                )
            else:
                chunk = await reader.read(bytes_needed)
            if not chunk:
                raise ConnectionResetError("Connection closed")
            remaining.extend(chunk)
            msg, remaining, bytes_needed = _try_extract_message(remaining)
    except BaseException:
        # The caller never sees fd on this path, so close it here or it
        # leaks. BaseException also covers task cancellation.
        if fd >= 0:
            os.close(fd)
        raise

    return msg, fd, remaining


# ==================== Sync (for client) ====================


def send_message_sync(sock, msg: Message, fd: int = -1) -> None:
    """Send a length-prefixed message with optional FD via SCM_RIGHTS.

    Args:
        sock: Connected socket to send on.
        msg: Message to encode and frame.
        fd: File descriptor to pass alongside the frame. A negative value
            (the default) sends the frame only.
    """
    frame = _frame_message(msg)
    if fd < 0:
        sock.sendall(frame)
        return

    # send_fds reports how many bytes actually went out, which can be fewer
    # than the frame (e.g. on a socket with a timeout). The descriptor
    # travels with the bytes already sent, so push the rest as plain data
    # instead of silently truncating the frame.
    sent = socket.send_fds(sock, [frame], [fd])
    if sent < len(frame):
        sock.sendall(frame[sent:])


def recv_message_sync(
    sock, recv_buffer: Optional[bytearray] = None
) -> Tuple[Optional[Message], int, bytearray]:
    """Receive a length-prefixed message with optional FD.

    Args:
        sock: Connected socket to read from.
        recv_buffer: Bytes left over from a previous call, if any. The
            buffer is extended in place while reading.

    Returns:
        ``(message, fd, remaining_buffer)``. ``fd`` is -1 if none was sent,
        and always -1 when the message was already complete in
        ``recv_buffer``. A returned descriptor is not closed here.

    Raises:
        ConnectionResetError: If the peer closes the connection before a
            complete message has been read.
    """
    if recv_buffer is None:
        recv_buffer = bytearray()

    # Check if complete message already in buffer
    msg, remaining, _ = _try_extract_message(recv_buffer)
    if msg is not None:
        return msg, -1, remaining

    # Receive more data (with potential FD); only this first read can pick
    # up a descriptor, the follow-up reads below use plain recv().
    raw_msg, fds, _flags, _addr = socket.recv_fds(sock, 65536, 1)
    fd = fds[0] if fds else -1

    try:
        if not raw_msg:
            raise ConnectionResetError("Connection closed")
        recv_buffer.extend(raw_msg)

        # Try to extract message, read more if needed
        msg, remaining, bytes_needed = _try_extract_message(recv_buffer)
        while msg is None and bytes_needed > 0:
            chunk = sock.recv(bytes_needed)
            if not chunk:
                raise ConnectionResetError("Connection closed")
            remaining.extend(chunk)
            msg, remaining, bytes_needed = _try_extract_message(remaining)
    except BaseException:
        # The caller never sees fd on this path, so close it here or it
        # leaks.
        if fd >= 0:
            os.close(fd)
        raise

    return msg, fd, remaining