# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import multiprocessing
import socket
import threading
import time
from unittest.mock import patch

import pytest

12
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

# Number of seconds mock_run_api_server_worker sleeps before returning.
# Individual tests reassign this through ``global`` before creating processes.
# NOTE(review): a child started with the "spawn" start method re-imports this
# module and would see this default rather than a value reassigned by a test;
# confirm which start method APIServerProcessManager uses.
WORKER_RUNTIME_SECONDS = 0.5


# Mock implementation of run_api_server_worker
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
    """Stand-in worker entry point that simply stays alive for a while.

    Prints a start message containing ``client_config``, sleeps for
    ``WORKER_RUNTIME_SECONDS`` seconds, prints a completion message and
    returns. ``listen_address``, ``sock`` and ``args`` are accepted so the
    signature can be called like a worker function, but they are not used.
    """
    start_message = f"Mock worker started with client_config: {client_config}"
    print(start_message)

    # The sleep is the worker's whole lifetime; returning afterwards lets the
    # process exit normally.
    time.sleep(WORKER_RUNTIME_SECONDS)

    print("Mock worker completed successfully")


@pytest.fixture
def api_server_args():
    """Provide keyword arguments for constructing an APIServerProcessManager.

    Yields:
        A dict with the mock worker function, a listen address, a freshly
        created socket, a placeholder ``args`` string, the number of servers
        (3), and matching lists of input/output addresses plus a stats update
        address.

    The socket created here is closed when the test finishes so the file
    descriptor does not leak from one test to the next.
    """
    sock = socket.socket()
    try:
        yield {
            "target_server_fn": mock_run_api_server_worker,
            "listen_address": "localhost:8000",
            "sock": sock,
            "args": "test_args",  # Simple string to avoid pickling issues
            "num_servers": 3,
            "input_addresses": [
                "tcp://127.0.0.1:5001",
                "tcp://127.0.0.1:5002",
                "tcp://127.0.0.1:5003",
            ],
            "output_addresses": [
                "tcp://127.0.0.1:6001",
                "tcp://127.0.0.1:6002",
                "tcp://127.0.0.1:6003",
            ],
            "stats_update_address": "tcp://127.0.0.1:7000",
        }
    finally:
        # Closing an already-closed socket is a no-op, so this is safe even
        # if something else closed it during the test.
        sock.close()


@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
    """Test initializing the APIServerProcessManager.

    Runs once with and once without the ``stats_update_address`` keyword.
    Checks that three processes are created and alive, that they are still
    alive halfway through the mock worker's runtime, and that none of them is
    alive after ``manager.close()``.
    """
    # Set the worker runtime to ensure tests complete in reasonable time
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.5

    # Copy the args so popping a key does not mutate the fixture's dict
    args = api_server_args.copy()

    if not with_stats_update:
        args.pop("stats_update_address")
    manager = APIServerProcessManager(**args)

    try:
        # Verify the manager was initialized correctly
        assert len(manager.processes) == 3

        # Verify all processes are running
        for proc in manager.processes:
            assert proc.is_alive()

        print("Waiting for processes to run...")
        time.sleep(WORKER_RUNTIME_SECONDS / 2)

        # They should still be alive at this point
        for proc in manager.processes:
            assert proc.is_alive()

    finally:
        # Always clean up the processes
        print("Cleaning up processes...")
        manager.close()

        # Give processes time to terminate
        time.sleep(0.2)

        # Verify all processes were terminated
        for proc in manager.processes:
            assert not proc.is_alive()


92
93
94
# NOTE(review): the fixture already passes the mock worker explicitly as
# ``target_server_fn``; confirm whether this patch is still required.
@patch(
    "vllm.entrypoints.cli.serve.run_api_server_worker_proc", mock_run_api_server_worker
)
def test_wait_for_completion_or_failure(api_server_args):
    """Test that wait_for_completion_or_failure works with failures.

    Starts the manager, runs ``wait_for_completion_or_failure`` in a
    background thread, terminates one worker, and then checks that the call
    raised an exception mentioning "died with exit code" and that no worker
    process is left alive.
    """
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 1.0

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        assert len(manager.processes) == 3

        # Create a result capture for the thread; an exception raised inside
        # a thread would otherwise be invisible to the test body.
        result: dict[str, Exception | None] = {"exception": None}

        def run_with_exception_capture():
            try:
                wait_for_completion_or_failure(api_server_manager=manager)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
        wait_thread.start()

        # Let all processes run for a short time
        time.sleep(0.2)

        # All processes should still be running
        assert all(proc.is_alive() for proc in manager.processes)

        # Now simulate a process failure
        print("Simulating process failure...")
        manager.processes[0].terminate()

        # Wait for the wait_for_completion_or_failure
        # to detect and handle the failure
        # This should trigger it to terminate all other processes
        wait_thread.join(timeout=1.0)

        # The wait thread should have exited
        assert not wait_thread.is_alive()

        # Verify that an exception was raised with appropriate error message
        assert result["exception"] is not None
        assert "died with exit code" in str(result["exception"])

        # All processes should now be terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), f"Process {i} should not be alive"

    finally:
        manager.close()
        time.sleep(0.2)


@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
    """Test that wait_for_completion_or_failure works in normal completion.

    Uses a short worker runtime, polls until every worker process has exited
    on its own, and then calls ``wait_for_completion_or_failure``, which is
    expected to return without raising. The polling loop has no deadline of
    its own; the ``timeout`` mark bounds the test.
    """
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.1

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        # Wait for processes to complete. The list is rebuilt on every pass
        # instead of removing entries while iterating over it, which would
        # skip the element following each removal.
        remaining_processes = list(manager.processes)
        while remaining_processes:
            remaining_processes = [
                proc for proc in remaining_processes if proc.is_alive()
            ]
            time.sleep(0.1)

        # Verify all processes have terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), (
                f"Process {i} still alive after waiting for completion"
            )

        # Now call wait_for_completion_or_failure
        # since all processes have already
        # terminated, it should return immediately
        # with no error
        wait_for_completion_or_failure(api_server_manager=manager)

    finally:
        # Clean up just in case
        manager.close()
        time.sleep(0.2)


@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
    """Test that wait_for_completion_or_failure handles additional processes.

    Starts one extra process wrapped in a coordinator-like object, runs
    ``wait_for_completion_or_failure`` in a background thread with that
    object passed as ``coordinator``, terminates the extra process, and checks
    that the call raised an error naming the process and that no API server
    process is left alive.
    """
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 100

    # Create and start the external process
    # (simulates local_engine_manager or coordinator)
    spawn_context = multiprocessing.get_context("spawn")
    external_proc = spawn_context.Process(
        target=mock_run_api_server_worker,
        # The mock worker requires three positional arguments; without them
        # the child would exit immediately with a TypeError instead of staying
        # alive until the test terminates it. The values are unused by the
        # mock, and None stands in for the socket so nothing unusual has to
        # be pickled for the spawned child.
        args=("localhost:8000", None, "test_args"),
        name="MockExternalProcess",
    )
    external_proc.start()

    # Create the class to simulate a coordinator
    class MockCoordinator:
        """Minimal object exposing ``proc`` and ``close()``."""

        def __init__(self, proc):
            self.proc = proc

        def close(self):
            # Terminate only if still running, then wait briefly for exit.
            if self.proc.is_alive():
                self.proc.terminate()
                self.proc.join(timeout=0.5)

    # Create a mock coordinator with the external process
    mock_coordinator = MockCoordinator(external_proc)

    # Create the API server manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        # Verify manager initialization
        assert len(manager.processes) == 3

        # Create a result capture for the thread; an exception raised inside
        # a thread would otherwise be invisible to the test body.
        result: dict[str, Exception | None] = {"exception": None}

        def run_with_exception_capture():
            try:
                wait_for_completion_or_failure(
                    api_server_manager=manager, coordinator=mock_coordinator
                )
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
        wait_thread.start()

        # Terminate the external process to trigger a failure
        time.sleep(0.2)
        external_proc.terminate()

        # Wait for the thread to detect the failure
        wait_thread.join(timeout=1.0)

        # The wait thread should have completed
        assert not wait_thread.is_alive(), (
            "wait_for_completion_or_failure thread still running"
        )

        # Verify that an exception was raised with appropriate error message
        assert result["exception"] is not None, "No exception was raised"
        error_message = str(result["exception"])
        assert "died with exit code" in error_message, (
            f"Unexpected error message: {error_message}"
        )
        assert "MockExternalProcess" in error_message, (
            f"Error doesn't mention external process: {error_message}"
        )

        # Verify that all API server processes were terminated as a result
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), f"API server process {i} was not terminated"

    finally:
        # Clean up
        manager.close()
        mock_coordinator.close()
        time.sleep(0.2)