# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
import time

import msgspec
import pytest

from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
                                        NullEventPublisher)


class EventSample(msgspec.Struct, tag=True, array_like=True):  # type: ignore
    """Small tagged msgspec struct used as the event payload in these tests.

    ``array_like=True`` makes msgspec encode instances as arrays rather than
    objects, so the declaration order of the fields below matters for the
    encoded form.
    """

    id: int  # identifier of the event within its batch
    value: str  # free-form string payload


class SampleBatch(EventBatch):
    """``EventBatch`` specialisation whose ``events`` are ``EventSample``s."""

    events: list[EventSample]  # payload events carried by this batch


def create_test_events(count: int) -> SampleBatch:
    """Build a ``SampleBatch`` of ``count`` sequential sample events.

    Event ``k`` gets ``id=k`` and ``value="test-k"``. The batch timestamp is
    read from ``time.time()`` after the events have been built.
    """
    samples = [
        EventSample(id=idx, value=f"test-{idx}") for idx in range(count)
    ]
    now = time.time()
    return SampleBatch(events=samples, ts=now)


def test_basic_publishing(publisher, subscriber):
    """Publish a single batch and check it arrives intact as sequence 0."""
    sent = create_test_events(5)
    publisher.publish(sent)

    result = subscriber.receive_one(timeout=1000)
    assert result is not None, "No message received"
    seq, received = result

    assert seq == 0, "Sequence number mismatch"
    # Allow a small tolerance on the round-tripped timestamp.
    assert received.ts == pytest.approx(sent.ts, abs=0.1), "Timestamp mismatch"
    assert len(received.events) == len(sent.events), (
        "Number of events mismatch")

    for expected_id, event in enumerate(received.events):
        assert event.id == expected_id, "Event id mismatch"
        assert event.value == f"test-{expected_id}", "Event value mismatch"


def test_multiple_events(publisher, subscriber):
    """Publish ten batches and check all ten arrive with sequences 0..9."""
    num_batches = 10
    for _ in range(num_batches):
        publisher.publish(create_test_events(2))

    # Poll once per expected batch. Falsy poll results (e.g. timeouts) are
    # dropped, so a missing message shows up as a short count below.
    polled = (subscriber.receive_one(timeout=100) for _ in range(num_batches))
    received = [item for item in polled if item]

    assert len(received) == num_batches, "Number of messages mismatch"
    assert [seq for seq, _ in received] == list(range(num_batches)), (
        "Sequence numbers mismatch")


def test_replay_mechanism(publisher, subscriber):
    """Test the replay mechanism works correctly.

    Publishes 19 batches, requests a replay starting at sequence 10, then
    publishes a 20th batch. Checks that replayed sequence numbers start no
    earlier than the requested sequence and form a gap-free, ordered run.
    """
    replay_start = 10
    for _ in range(19):
        publisher.publish(create_test_events(1))

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(replay_start)

    # 20th message; presumably lets the publisher service the pending
    # replay request -- TODO confirm against the publisher implementation.
    publisher.publish(create_test_events(1))

    replayed = subscriber.receive_replay()

    assert len(replayed) > 0, "No replayed messages received"
    seqs = [seq for seq, _ in replayed]
    assert all(seq >= replay_start for seq in seqs), (
        "Replayed messages start before the requested sequence")
    # Equality with a contiguous range checks both ordering and no gaps.
    assert seqs == list(range(min(seqs), max(seqs) + 1)), (
        "Replayed messages not consecutive")


def test_buffer_limit(publisher, subscriber, publisher_config):
    """Test buffer limit behavior.

    Publishes ``buffer_steps + overflow`` batches, requests a replay from
    sequence 0, then checks that at most ``buffer_steps`` batches are
    replayed and that the oldest replayed sequence is at least ``overflow``.
    """
    buffer_size = publisher_config.buffer_steps
    overflow = 10  # how many batches to publish beyond the buffer size

    # Publish more events than the buffer can hold
    for _ in range(buffer_size + overflow):
        publisher.publish(create_test_events(1))

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(0)

    publisher.publish(create_test_events(1))

    replayed = subscriber.receive_replay()

    assert len(replayed) <= buffer_size, "Can't replay more than buffer size"
    # Check before calling min(): an empty replay would otherwise surface
    # as an opaque ValueError rather than a clear test failure.
    assert replayed, "No replayed messages received"

    oldest_seq = min(seq for seq, _ in replayed)
    assert oldest_seq >= overflow, (
        f"The oldest sequence should be at least {overflow}")


def test_topic_filtering(publisher_config):
    """
    Test that a subscriber only receives messages matching its topic filter.

    A publisher on topic "foo" sends three batches. The "foo" subscriber must
    receive all of them and the "bar" subscriber must receive none.
    """
    from .conftest import MockSubscriber

    publisher_config.replay_endpoint = None
    publisher_config.topic = "foo"
    pub = EventPublisherFactory.create(publisher_config)

    # Track subscribers as they are created so that everything built so far
    # is cleaned up even if a later constructor raises.
    subscribers = []
    try:
        sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
        subscribers.append(sub_foo)
        sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
        subscribers.append(sub_bar)

        # Give the subscribers a moment to connect before publishing;
        # presumably messages sent before a subscription is established are
        # dropped (PUB/SUB-style) -- TODO confirm.
        time.sleep(0.1)

        for _ in range(3):
            pub.publish(create_test_events(1))

        foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is not None for msg in foo_received), (
            "Subscriber with matching topic should receive messages")

        bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is None for msg in bar_received), (
            "Subscriber with non-matching topic should receive no messages")
    finally:
        pub.shutdown()
        for sub in subscribers:
            sub.close()


def test_high_volume(publisher, subscriber):
    """Test publishing and receiving a high volume of events.

    A background thread publishes ``num_batches`` batches while the main
    thread drains the subscriber for up to ``receive_timeout_s`` seconds.
    At least 90% of the batches must arrive, with sequence numbers in order,
    and the publisher thread must not have raised.
    """
    num_batches = 10_000
    events_per_batch = 100
    receive_timeout_s = 10.0
    # threading.Thread does not propagate exceptions to join(), so failures
    # in the publisher thread are collected here and asserted on below.
    publish_errors: list[Exception] = []

    # Publish events in a separate thread to not block
    def publish_events():
        try:
            for i in range(num_batches):
                batch = create_test_events(events_per_batch)
                publisher.publish(batch)
                # Small delay to avoid overwhelming
                if i % 100 == 0:
                    time.sleep(0.01)
        except Exception as exc:
            publish_errors.append(exc)

    received: list[tuple[int, SampleBatch]] = []

    publisher_thread = threading.Thread(target=publish_events)
    publisher_thread.start()

    # monotonic() is unaffected by wall-clock adjustments, unlike time().
    deadline = time.monotonic() + receive_timeout_s
    while len(received) < num_batches and time.monotonic() < deadline:
        result = subscriber.receive_one(timeout=100)
        if result:
            received.append(result)

    publisher_thread.join()

    assert not publish_errors, (
        f"Publisher thread failed: {publish_errors[0]!r}")
    assert len(received) >= num_batches * 0.9, (
        "We should have received most messages")

    seqs = [seq for seq, _ in received]
    assert sorted(seqs) == seqs, "Sequence numbers should be in order"


def test_null_publisher():
    """NullEventPublisher accepts publish and shutdown calls without raising."""
    null_pub = NullEventPublisher()

    # Neither call is expected to raise; the test passes if they return.
    null_pub.publish(create_test_events(5))
    null_pub.shutdown()