# test_shm_broadcast.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import random
5
import threading
6
import time
7
from unittest import mock
8

9
import multiprocess as mp
10
import numpy as np
11
import pytest
12
13
import torch.distributed as dist

14
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
15
from vllm.distributed.utils import StatelessProcessGroup
16
from vllm.utils.network_utils import get_open_port
17
from vllm.utils.system_utils import update_environment_variables
18
19


20
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
    """Return ``n`` pseudo-random integer arrays that are reproducible per seed.

    A private ``RandomState`` is used instead of reseeding NumPy's global
    generator, so calling this helper does not disturb other users of
    ``np.random`` in the same process. The values produced are the same as
    those obtained from ``np.random.seed(seed)`` followed by the module-level
    ``np.random.randint`` calls.

    Args:
        n: Number of arrays to generate.
        seed: Seed for the generator; equal seeds give equal arrays.

    Returns:
        A list of ``n`` one-dimensional arrays. Each length is drawn from
        ``[1, 10_000)`` and each element from ``[1, 100)``.
    """
    rng = np.random.RandomState(seed)
    sizes = rng.randint(1, 10_000, n)
    # on average, each array will have 5k elements
    # with int64, each array will have 40kb
    return [rng.randint(1, 100, size) for size in sizes]


28
29
30
31
32
33
34
35
def distributed_run(fn, world_size, timeout=60):
    """Run a function in multiple processes with proper error handling.

    Every child is started as ``fn(env)``, where ``env`` is a dict holding
    the RANK, LOCAL_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE, MASTER_ADDR and
    MASTER_PORT values for that rank. The children are polled every 100 ms;
    polling stops as soon as one child has exited with a non-zero code, all
    children have exited, or ``timeout`` has elapsed. Any child still alive
    at that point is killed.

    Args:
        fn: Function to run in each process
        world_size: Number of processes to spawn
        timeout: Maximum time in seconds to wait for processes (default: 60)

    Raises:
        AssertionError: If any child exits with a non-zero exit code, or if
            children are still running when ``timeout`` expires without any
            failure having been observed.
    """
    processes = []
    for rank in range(world_size):
        # NOTE(review): the rendezvous port is fixed; concurrent runs on the
        # same host would presumably collide - confirm whether a dynamically
        # chosen port is preferable here.
        env = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
        p = mp.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    # Monitor processes and fail fast if any process fails.
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    failed_processes = []
    while time.monotonic() < deadline:
        still_running = False
        for rank, p in enumerate(processes):
            if p.is_alive():
                still_running = True
            elif p.exitcode != 0:
                failed_processes.append((rank, p.exitcode))

        if failed_processes or not still_running:
            break
        time.sleep(0.1)  # Check every 100ms

    # Whatever is still alive must be stopped. If no failure was seen, the
    # only way to get here with live children is the deadline expiring, so
    # those ranks are reported as timed out instead of passing silently.
    # If a failure was seen, the survivors are just collateral of fail-fast.
    failure_seen = bool(failed_processes)
    reported_ranks = {rank for rank, _ in failed_processes}
    timed_out_ranks = []
    for rank, p in enumerate(processes):
        if p.is_alive():
            p.kill()
            p.join()
            if not failure_seen:
                timed_out_ranks.append(rank)
        elif p.exitcode != 0 and rank not in reported_ranks:
            # Exited with an error after the last poll.
            failed_processes.append((rank, p.exitcode))

    # Report failures
    if failed_processes or timed_out_ranks:
        lines = ["Distributed test failed:\n"]
        lines.extend(
            f"  Rank {rank}: Exit code {status}\n"
            for rank, status in failed_processes
        )
        lines.extend(
            f"  Rank {rank}: Timed out after {timeout} seconds\n"
            for rank in timed_out_ranks
        )
        raise AssertionError("".join(lines))
81
82
83


def worker_fn_wrapper(fn):
    """Turn a zero-argument worker into a process target that takes ``env``.

    Args:
        fn: Callable taking no arguments; it is invoked after the
            environment has been applied and the process group initialized.

    Returns:
        A function ``wrapped_fn(env)`` that passes ``env`` to
        ``update_environment_variables``, calls
        ``dist.init_process_group(backend="gloo")`` and then runs ``fn()``.
    """

    # `mp.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    """Broadcast many arrays through a ``MessageQueue`` and verify each reader.

    The same check is run twice: first over the default torch.distributed
    group, then over a ``StatelessProcessGroup`` created on an endpoint that
    rank 0 chooses and broadcasts. In each round rank 2 writes, and every
    other rank regenerates the expected arrays from the shared seed and
    compares them with what it receives.
    """
    rank = dist.get_rank()
    # Rank 0 picks the endpoint and shares it so that all ranks join the
    # same StatelessProcessGroup.
    if rank == 0:
        port = get_open_port()
        ip = "127.0.0.1"
        dist.broadcast_object_list([ip, port], src=0)
    else:
        recv = [None, None]
        dist.broadcast_object_list(recv, src=0)
        ip, port = recv  # type: ignore

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
    # Captured once so the per-round group check is a plain identity test.
    world_pg = dist.group.WORLD

    for pg in [world_pg, stateless_pg]:
        uses_torch_pg = pg is world_pg
        writer_rank = 2
        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank
        )
        # The writer chooses the seed; it is always shared over the default
        # torch.distributed group, whichever group is under test.
        if rank == writer_rank:
            seed = random.randint(0, 1000)
            dist.broadcast_object_list([seed], writer_rank)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, writer_rank)
            seed = recv[0]  # type: ignore

        if uses_torch_pg:
            dist.barrier()
        else:
            pg.barrier()

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")
        # test broadcasting with about 400MB of data
        N = 10_000
        # Writer and readers derive the same arrays from the shared seed.
        arrs = get_arrays(N, seed)
        if rank == writer_rank:
            for x in arrs:
                broadcaster.broadcast_object(x)
                time.sleep(random.random() / 1000)
        else:
            for x in arrs:
                y = broadcaster.broadcast_object(None)
                assert np.array_equal(x, y)
                time.sleep(random.random() / 1000)

        if uses_torch_pg:
            dist.barrier()
            print(f"torch distributed passed the test! Rank {rank}")
        else:
            pg.barrier()
            print(f"StatelessProcessGroup passed the test! Rank {rank}")
150
151
152
153


def test_shm_broadcast():
    """Run the shared-memory broadcast worker on four processes."""
    distributed_run(fn=worker_fn, world_size=4)


@worker_fn_wrapper
def worker_fn_test_shutdown_busy():
    """Check that ``shutdown()`` cancels a reader whose busy window is huge.

    Readers first confirm that ``dequeue`` times out while nothing has been
    enqueued, then trigger ``shutdown()`` from a helper thread and expect the
    next ``dequeue`` to raise ``RuntimeError`` matching "cancelled". The
    writer rank only takes part in the final barrier.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not message_queue._is_writer:
        # Put into busy mode
        message_queue._spin_condition.busy_loop_s = 9999

        shutdown_event = threading.Event()

        def shutdown_thread(mq, shutdown_event):
            shutdown_event.wait()
            mq.shutdown()

        # Daemon thread: if an assertion below fails, the process must still
        # be able to exit instead of waiting forever on this thread.
        stopper = threading.Thread(
            target=shutdown_thread,
            args=(message_queue, shutdown_event),
            daemon=True,
        )
        stopper.start()

        try:
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
        finally:
            # Always release the helper thread, even on failure.
            shutdown_event.set()

        with pytest.raises(RuntimeError, match="cancelled"):
            message_queue.dequeue(timeout=1)

        # Bounded wait so shutdown() can finish before the process moves on.
        stopper.join(timeout=5)
        assert message_queue.shutting_down

    print(f"torch distributed passed the test! Rank {rank}")
    dist.barrier()


def test_message_queue_shutdown_busy(caplog_vllm):
    """Run the busy-mode shutdown worker on four processes and dump the log."""
    distributed_run(fn=worker_fn_test_shutdown_busy, world_size=4)
    captured_log = caplog_vllm.text
    print(captured_log)


@worker_fn_wrapper
def worker_fn_test_shutdown_idle():
    """Check that ``shutdown()`` cancels a reader whose ``last_read`` is zero.

    Readers first confirm that ``dequeue`` times out while nothing has been
    enqueued, then trigger ``shutdown()`` from a helper thread and expect the
    next ``dequeue`` to raise ``RuntimeError`` matching "cancelled". The
    writer rank only takes part in the final barrier.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not message_queue._is_writer:
        # Put into idle mode
        message_queue._spin_condition.last_read = 0

        shutdown_event = threading.Event()

        def shutdown_thread(mq, shutdown_event):
            shutdown_event.wait()
            mq.shutdown()

        # Daemon thread: if an assertion below fails, the process must still
        # be able to exit instead of waiting forever on this thread.
        stopper = threading.Thread(
            target=shutdown_thread,
            args=(message_queue, shutdown_event),
            daemon=True,
        )
        stopper.start()

        try:
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
        finally:
            # Always release the helper thread, even on failure.
            shutdown_event.set()

        with pytest.raises(RuntimeError, match="cancelled"):
            message_queue.dequeue(timeout=1)

        # Bounded wait so shutdown() can finish before the process moves on.
        stopper.join(timeout=5)
        assert message_queue.shutting_down

    print(f"torch distributed passed the test! Rank {rank}")
    dist.barrier()


def test_message_queue_shutdown_idle():
    """Run the idle-mode shutdown worker on four processes."""
    distributed_run(fn=worker_fn_test_shutdown_idle, world_size=4)


@worker_fn_wrapper
def worker_fn_test_idle_to_busy():
    """Count ``_spin_condition.wait`` calls as a reader goes from idle to busy.

    The writer (rank 2) enqueues two messages, each after a barrier and a
    short sleep. Readers start with ``last_read = 0``, observe one timeout,
    then receive both messages while asserting how often the wrapped
    ``wait`` has been called at each step.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    message1 = "hello world"
    # Every rank builds its own expected copy of message2, so it must come
    # from a fixed-seed generator: the global NumPy RNG only matches across
    # processes when its state happens to be inherited from the parent.
    message2 = np.random.RandomState(0).randint(1, 100, 100)
    with mock.patch.object(
        message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
    ) as wrapped_wait:
        if not message_queue._is_writer:
            # Put into idle mode
            message_queue._spin_condition.last_read = 0

            # no messages, so expect a TimeoutError
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
            # wait should only be called once while idle
            assert wrapped_wait.call_count == 1

            # sync with the writer and wait for message1
            dist.barrier()
            recv_message = message_queue.dequeue(timeout=5)
            assert recv_message == message1
            # second call to wait, with a message read, this puts in a busy spin
            assert wrapped_wait.call_count == 2

            # sync with the writer and wait for message2
            dist.barrier()
            recv_message = message_queue.dequeue(timeout=1)
            assert np.array_equal(recv_message, message2)
            # in busy mode, we expect wait to have been called multiple times
            assert wrapped_wait.call_count > 3
        else:
            # writer writes two messages in sync with the reader
            dist.barrier()
            # sleep delays the send to ensure reader enters the read loop
            time.sleep(0.1)
            message_queue.enqueue(message1)

            dist.barrier()
            time.sleep(0.1)
            message_queue.enqueue(message2)

    message_queue.shutdown()
    assert message_queue.shutting_down
    print(f"torch distributed passed the test! Rank {rank}")


def test_message_queue_idle_wake():
    """Run the idle-to-busy worker on four processes."""
    distributed_run(fn=worker_fn_test_idle_to_busy, world_size=4)


@worker_fn_wrapper
def worker_fn_test_busy_to_idle():
    """Count ``_spin_condition.wait`` calls as a reader goes from busy to idle.

    The writer (rank 2) enqueues two messages, each after a barrier and a
    short sleep. Readers start with a very large ``busy_loop_s``, receive the
    first message, set ``busy_loop_s`` to zero, observe one timeout, and then
    expect exactly one more ``wait`` call while receiving the second message.
    """
    rank = dist.get_rank()
    writer_rank = 2
    queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    first_payload = 12345
    second_payload = [0, 1, 2]
    spin = queue._spin_condition
    with mock.patch.object(spin, "wait", wraps=spin.wait) as wait_spy:
        if queue._is_writer:
            # The writer sends both payloads in step with the readers; the
            # sleep delays each send so readers are already in the read loop.
            for payload in (first_payload, second_payload):
                dist.barrier()
                time.sleep(0.1)
                queue.enqueue(payload)
        else:
            # Start out in busy mode.
            spin.busy_loop_s = 9999

            # Meet the writer at the barrier, then receive the first payload.
            dist.barrier()
            received = queue.dequeue(timeout=1)
            assert received == first_payload
            # Busy mode is expected to call wait many times.
            assert wait_spy.call_count > 1

            # Pretend the busy window has ended.
            spin.busy_loop_s = 0
            # One timed-out read to enter idle mode; remember the count after it.
            with pytest.raises(TimeoutError):
                queue.dequeue(timeout=0.01)
            calls_before = wait_spy.call_count

            # Meet the writer again and receive the second payload.
            dist.barrier()
            received = queue.dequeue(timeout=1)
            assert received == second_payload

            # After going idle, exactly one additional wait call is expected.
            assert wait_spy.call_count == calls_before + 1

    queue.shutdown()
    assert queue.shutting_down
    print(f"torch distributed passed the test! Rank {rank}")


def test_message_queue_busy_to_idle():
    """Run the busy-to-idle worker on four processes."""
    distributed_run(fn=worker_fn_test_busy_to_idle, world_size=4)


def test_warning_logs(caplog_vllm):
    """
    Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
    when indefinite=False, and are not emitted when indefinite=True.

    Both queues are shut down in a ``finally`` block so that a failing
    assertion does not skip the cleanup.
    """
    warning_text = "No available shared memory broadcast block found in 0 seconds"

    # Patch the warning log interval to every 1 ms during reads
    with mock.patch(
        "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
        new=0.001,  # 1 ms
    ):
        writer = MessageQueue(
            n_reader=1,
            n_local_reader=1,
            max_chunk_bytes=1024 * 1024,  # 1MB chunks
            max_chunks=10,
        )
        reader = None
        try:
            reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
            writer.wait_until_ready()
            reader.wait_until_ready()

            # We should have at least one warning log here
            # "0 seconds" expected due to rounding of 1ms test interval
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=False)
            assert any(
                warning_text in record.message for record in caplog_vllm.records
            )
            caplog_vllm.clear()

            # We should have no warnings this time
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=True)
            assert not any(
                warning_text in record.message for record in caplog_vllm.records
            )
        finally:
            # Clean up when done, writer first and then the reader
            writer.shutdown()
            if reader is not None:
                reader.shutdown()