# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from typing import Optional, Union

import msgspec
import msgspec.msgpack
import pytest
import zmq

11
from vllm.config.kv_events import KVEventsConfig
12
13
14
15
from vllm.distributed.kv_events import EventPublisherFactory

from .test_events import SampleBatch

# Data-parallel rank passed to ``EventPublisherFactory.create`` by the
# ``publisher`` fixture.
DP_RANK = 0


@pytest.fixture
def random_port():
    """Generate a random port number for testing.

    The range is [10000, 59900]. With that upper bound, ``random_port + 100``
    (the TCP replay port used by ``publisher_config``) is at most 60000.
    """
    return random.randint(10000, 59900)


@pytest.fixture
def publisher_config(random_port, request):
    """Create a ZMQ publisher config for tests.

    The transport is chosen through indirect parametrization
    (``request.param``):

    * ``"inproc"`` (also used when the fixture is not parametrized):
      in-process endpoints ``inproc://test-<port>`` and ``...-replay``.
    * any other value: TCP endpoints bound on all interfaces, with the
      replay endpoint on ``random_port + 100``.
    """
    # ``request.param`` only exists when the fixture is parametrized.
    how = getattr(request, "param", "inproc")

    if how == "inproc":
        endpoint = f"inproc://test-{random_port}"
        replay_endpoint = endpoint + "-replay"
    else:
        endpoint = f"tcp://*:{random_port}"
        replay_endpoint = f"tcp://*:{random_port + 100}"

    return KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=endpoint,
        replay_endpoint=replay_endpoint,
        buffer_steps=100,
        hwm=1000,
        topic="test",
    )


@pytest.fixture
def publisher(publisher_config):
    """Create a publisher from ``publisher_config``; shut it down on teardown."""
    pub = EventPublisherFactory.create(publisher_config, DP_RANK)
    yield pub
    pub.shutdown()


@pytest.fixture
def subscriber(publisher_config):
    """Create a ``MockSubscriber`` connected to the publisher's endpoints.

    The publisher binds TCP endpoints on the wildcard host (``tcp://*``);
    those are rewritten to ``127.0.0.1`` so the subscriber can connect.
    ``inproc`` endpoints are used unchanged.
    """

    def _connectable(addr):
        # Only rewrite wildcard TCP binds; leave everything else as-is.
        if addr and addr.startswith("tcp://*"):
            return addr.replace("*", "127.0.0.1")
        return addr

    endpoint = _connectable(publisher_config.endpoint)
    replay_endpoint = _connectable(publisher_config.replay_endpoint)

    sub = MockSubscriber(
        [endpoint],
        [replay_endpoint] if replay_endpoint else None,
        publisher_config.topic,
    )
    yield sub
    sub.close()


class MockSubscriber:
    """Helper class to receive and verify published events.

    Opens one SUB socket connected to every publisher endpoint (filtered on
    ``topic``) and, optionally, one REQ socket per replay endpoint.
    Sockets are created from the shared ``zmq.Context.instance()``.
    """

    def __init__(
        self,
        pub_endpoints: Union[str, list[str]],
        replay_endpoints: Optional[Union[str, list[str]]] = None,
        topic: str = "",
        decode_type=SampleBatch,
    ):
        """
        Args:
            pub_endpoints: Publisher endpoint(s) to connect the SUB socket to.
            replay_endpoints: Optional replay endpoint(s); one REQ socket is
                created per endpoint.
            topic: Topic prefix to subscribe to; also checked on every
                received message in ``receive_one``.
            decode_type: Type handed to ``msgspec.msgpack.Decoder`` for
                decoding payloads.
        """
        self.ctx = zmq.Context.instance()

        # Accept a single endpoint string as shorthand for a one-item list.
        if isinstance(pub_endpoints, str):
            pub_endpoints = [pub_endpoints]
        if isinstance(replay_endpoints, str):
            replay_endpoints = [replay_endpoints]

        self.topic = topic
        self.topic_bytes = topic.encode("utf-8")

        # A single SUB socket connected to all publisher endpoints.
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes)
        for endpoint in pub_endpoints:
            self.sub.connect(endpoint)

        # One REQ socket per replay endpoint (empty list if none given).
        self.replay_sockets = []
        for replay_endpoint in replay_endpoints or []:
            replay = self.ctx.socket(zmq.REQ)
            replay.connect(replay_endpoint)
            self.replay_sockets.append(replay)

        self.received_msgs: list[tuple[int, SampleBatch]] = []
        self.last_seq = -1
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)

    def receive_one(self, timeout=1000) -> Optional[tuple[int, SampleBatch]]:
        """Receive a single published message.

        Args:
            timeout: Poll timeout in milliseconds.

        Returns:
            ``(seq, data)`` for the received message, or ``None`` if nothing
            arrived within ``timeout``. The message is also appended to
            ``received_msgs`` and ``last_seq`` is updated.
        """
        if not self.sub.poll(timeout):
            return None

        # Published messages are three frames: topic, 8-byte seq, payload.
        topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
        assert topic_bytes == self.topic_bytes

        seq = int.from_bytes(seq_bytes, "big")
        data = self.decoder.decode(payload)
        self.last_seq = seq
        self.received_msgs.append((seq, data))
        return seq, data

    def _get_replay_socket(self, socket_idx: int):
        """Return the replay socket at ``socket_idx``.

        Raises:
            ValueError: If no replay sockets exist or the index is out of
                range (negative indices are rejected too).
        """
        if not self.replay_sockets:
            raise ValueError("Replay sockets not initialized")
        if not 0 <= socket_idx < len(self.replay_sockets):
            raise ValueError(f"Invalid socket index {socket_idx}")
        return self.replay_sockets[socket_idx]

    def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
        """Request replay of messages starting from ``start_seq``.

        The sequence number is sent as an 8-byte big-endian integer on the
        replay socket at ``socket_idx``.

        Raises:
            ValueError: If the replay socket index is invalid.
        """
        replay_socket = self._get_replay_socket(socket_idx)
        replay_socket.send(start_seq.to_bytes(8, "big"))

    def receive_replay(
        self, socket_idx: int = 0, timeout: int = 1000
    ) -> list[tuple[int, SampleBatch]]:
        """Receive replayed messages from a specific replay socket.

        Reads ``(seq, payload)`` frame pairs until one of these happens:
        a poll times out after ``timeout`` ms, a message arrives whose last
        frame is empty (the end-of-replay marker), or a ``zmq.ZMQError``
        is raised.

        Raises:
            ValueError: If the replay socket index is invalid.
        """
        replay_socket = self._get_replay_socket(socket_idx)

        replayed: list[tuple[int, SampleBatch]] = []
        while True:
            try:
                if not replay_socket.poll(timeout):
                    break
                frames = replay_socket.recv_multipart()
            except zmq.ZMQError:
                # Treat socket errors as the end of the replay stream.
                break

            if not frames or not frames[-1]:
                # End of replay marker
                break

            seq_bytes, payload = frames
            seq = int.from_bytes(seq_bytes, "big")
            replayed.append((seq, self.decoder.decode(payload)))

        return replayed

    def close(self):
        """Close the SUB socket and every replay socket.

        ``linger=0`` discards unsent messages (e.g. an unanswered replay
        request) so they cannot delay termination of the shared context.
        The shared ``zmq.Context.instance()`` itself is not terminated.
        """
        self.sub.close(linger=0)
        for replay in self.replay_sockets:
            replay.close(linger=0)