# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Request migration fault-tolerance tests for the SGLang backend.

Test Execution Times (Last Run: 2026-01-13):
- test_request_migration_sglang_aggregated: ~75s
- test_request_migration_sglang_prefill: N/A
- test_request_migration_sglang_kv_transfer: N/A
- test_request_migration_sglang_decode: ~75s
"""

12
13
14
15
16
17
18
import logging
import os
import shutil

import pytest

from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
19
from tests.utils.managed_process import ManagedProcess
20
from tests.utils.payloads import check_models_api
21
from tests.utils.port_utils import allocate_port, deallocate_port
22

23
24
# Customized utils for migration tests
from .utils import DynamoFrontendProcess, run_migration_test
25
26
27
28

logger = logging.getLogger(__name__)

# Module-level marks: applied to every test in this file. Each parametrize
# entry below adds one axis, so every test is expanded over the cross product
# of all axes. The order of the entries is kept as-is because it determines
# how the generated test ids are composed.
pytestmark = [
    pytest.mark.fault_tolerance,
    pytest.mark.sglang,
    pytest.mark.gpu_1,
    pytest.mark.e2e,
    pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
    # Migration on (limit 3) vs. off (limit 0).
    pytest.mark.parametrize(
        "migration_limit",
        [
            pytest.param(3, id="migration_enabled"),
            pytest.param(0, id="migration_disabled"),
        ],
    ),
    # Sequence-length cap: unset, a cap that is never hit, and a cap of 1
    # that is always exceeded.
    pytest.mark.parametrize(
        "migration_max_seq_len",
        [None, 1_000_000, 1],
        ids=[
            "max_seq_len_disabled",
            "max_seq_len_not_exceeded",
            "max_seq_len_exceeded",
        ],
    ),
    # How the worker is taken down; the graceful path is tolerated as a
    # failure until SGLang supports it.
    pytest.mark.parametrize(
        "immediate_kill",
        [
            pytest.param(True, id="worker_failure"),
            pytest.param(
                False,
                id="graceful_shutdown",
                marks=pytest.mark.xfail(
                    strict=False, reason="SGLang graceful shutdown not yet implemented"
                ),
            ),
        ],
    ),
    # Only the chat API is exercised for now.
    pytest.mark.parametrize(
        "request_api",
        [
            "chat",
            pytest.param(
                "completion",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    # Only streaming responses are exercised for now.
    pytest.mark.parametrize(
        "stream",
        [
            pytest.param(True, id="stream"),
            pytest.param(
                False,
                id="unary",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    # Routed through the `request_plane` fixture (indirect parametrization).
    pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
]


class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with SGLang backend

    Supports both aggregated mode (single worker) and disaggregated mode
    (separate prefill and decode workers).

    Args:
        request: pytest request fixture
        worker_id: Unique identifier for the worker (e.g., "worker1", "worker2").
            In disaggregated mode the last character is appended to "1234" to
            form the bootstrap port, so it is expected to end in a digit.
        frontend_port: Port where the frontend is running
        disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated
    """

    def __init__(
        self,
        request,
        worker_id: str,
        frontend_port: int,
        disagg_mode: str | None = None,
    ):
        self.worker_id = worker_id
        # Reserved here and released again in __exit__.
        self.system_port = allocate_port(9100)
        self.disagg_mode = disagg_mode

        command = [
            "python3",
            "-m",
            "dynamo.sglang",
            "--model-path",
            FAULT_TOLERANCE_MODEL_NAME,
            "--served-model-name",
            FAULT_TOLERANCE_MODEL_NAME,
            "--trust-remote-code",
            "--page-size",
            "16",
            "--tp",
            "1",
            "--mem-fraction-static",
            "0.3",
            "--context-length",
            "8192",
        ]

        if disagg_mode is None:
            # Aggregated
            command.append("--skip-tokenizer-init")
        else:
            # Disaggregated
            command.extend(
                [
                    "--disaggregation-mode",
                    disagg_mode,
                    "--disaggregation-bootstrap-port",
                    # One bootstrap port per worker: "1234" + last character of
                    # worker_id (e.g. "worker2" -> 12342).
                    f"1234{worker_id[-1]}",
                    "--host",
                    "0.0.0.0",
                    "--disaggregation-transfer-backend",
                    "nixl",
                ]
            )
            if disagg_mode == "prefill":
                # NOTE(review): every prefill worker is given the same --port;
                # confirm this cannot clash when two prefill workers run at once.
                command.extend(["--port", "40000"])

        # Set environment variables
        env = os.environ.copy()
        env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")

        env["DYN_LOG"] = "debug"
        # Disable canary health check - these tests expect full control over requests
        # sent to the workers where canary health check intermittently sends dummy
        # requests to workers interfering with the test process which may cause
        # intermittent failures
        env["DYN_HEALTH_CHECK_ENABLED"] = "false"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = str(self.system_port)
        env["DYN_HTTP_PORT"] = str(frontend_port)

        # Disable backend shutdown grace period for all migration tests
        env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"

        # Configure health check based on worker type: every worker is probed on
        # its own system port; aggregated and decode workers are additionally
        # checked through the frontend's model listing.
        health_check_urls = [
            (f"http://localhost:{self.system_port}/health", self.is_ready)
        ]
        if disagg_mode is None or disagg_mode == "decode":
            health_check_urls.append(
                (f"http://localhost:{frontend_port}/v1/models", check_models_api)
            )

        # TODO: Have the managed process take a command name explicitly to distinguish
        #       between processes started with the same command.
        log_dir = f"{request.node.name}_{worker_id}"

        # Clean up any existing log directory from previous runs
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass

        super().__init__(
            command=command,
            env=env,
            health_check_urls=health_check_urls,
            timeout=300,
            display_output=True,
            terminate_all_matching_process_names=False,
            stragglers=["SGLANG:EngineCore"],
            straggler_commands=["-m dynamo.sglang"],
            log_dir=log_dir,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tear down the worker, then release its allocated system port.

        The parent's ``__exit__`` runs first so the port is only released once
        the parent has finished its teardown; releasing it earlier could let
        another worker be handed a port this one is still using. The release
        is best-effort: a failure is logged and never masks the parent's
        result or exception.

        Returns:
            Whatever the parent ``__exit__`` returns.
        """
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            try:
                # system_port is always set in __init__
                deallocate_port(self.system_port)
            except Exception as e:
                logger.warning(f"Failed to release SGLang worker port: {e}")

    def is_ready(self, response) -> bool:
        """Check the health of the worker process

        Args:
            response: Health-check response; only its ``json()`` method is used.

        Returns:
            True when the JSON body is an object whose "status" is "ready",
            False otherwise (including a body that is not valid JSON or is
            not a JSON object).
        """
        try:
            data = response.json()
        except ValueError:
            logger.warning(f"{self.worker_id} health response is not valid JSON")
            return False

        # Valid JSON that is not an object (e.g. a list) has no "status" field;
        # treat it as "not ready" rather than raising AttributeError.
        status = data.get("status") if isinstance(data, dict) else None
        if status == "ready":
            logger.info(f"{self.worker_id} status is ready")
            return True

        logger.warning(f"{self.worker_id} status is not ready: {status}")
        return False


220
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.post_merge
def test_request_migration_sglang_aggregated(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test across two aggregated workers.

    Setup: 1 frontend + 2 aggregated workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    # Translate the parametrized API name into the flag the helper expects.
    chat_api_selected = request_api == "chat"

    # Frontend comes up first; both workers register against its port.
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # Bring up the two aggregated workers one after the other.
        with DynamoWorkerProcess(
            request,
            worker_id="worker1",
            frontend_port=frontend.frontend_port,
        ) as first_worker:
            logger.info(f"Worker 1 PID: {first_worker.get_pid()}")

            with DynamoWorkerProcess(
                request,
                worker_id="worker2",
                frontend_port=frontend.frontend_port,
            ) as second_worker:
                logger.info(f"Worker 2 PID: {second_worker.get_pid()}")

                # Hand over to the shared migration scenario.
                run_migration_test(
                    frontend,
                    first_worker,
                    second_worker,
                    receiving_pattern="New Request ID: ",
                    migration_limit=migration_limit,
                    migration_max_seq_len=migration_max_seq_len,
                    immediate_kill=immediate_kill,
                    use_chat_completion=chat_api_selected,
                    stream=stream,
                )

276

277
278
279
@pytest.mark.skip(reason="Cannot reliably migrate at Prefill that finish < 1 ms")
@pytest.mark.xfail(strict=False, reason="Prefill migration not yet supported")
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_prefill(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test between prefill workers (disaggregated mode).

    Setup: 1 decode worker + 2 prefill workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    # Translate the parametrized API name into the flag the helper expects.
    chat_api_selected = request_api == "chat"

    # Frontend comes up first; all workers register against its port.
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # The decode worker must be running before the prefill workers connect.
        with DynamoWorkerProcess(
            request,
            worker_id="worker0",
            frontend_port=frontend.frontend_port,
            disagg_mode="decode",
        ) as decode_worker:
            logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")

            # Bring up the two prefill workers the request migrates between.
            with DynamoWorkerProcess(
                request,
                worker_id="worker1",
                frontend_port=frontend.frontend_port,
                disagg_mode="prefill",
            ) as first_prefill:
                logger.info(f"Prefill Worker 1 PID: {first_prefill.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    worker_id="worker2",
                    frontend_port=frontend.frontend_port,
                    disagg_mode="prefill",
                ) as second_prefill:
                    logger.info(f"Prefill Worker 2 PID: {second_prefill.get_pid()}")

                    # Hand over to the shared migration scenario.
                    run_migration_test(
                        frontend,
                        first_prefill,
                        second_prefill,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=chat_api_selected,
                        stream=stream,
                        use_long_prompt=True,
                    )
351

352

353
354
@pytest.mark.skip(reason="KV cache transfer may fail")
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_kv_transfer(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test during KV transfer (disaggregated mode).

    Setup: 1 prefill worker + 2 decode workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    # Translate the parametrized API name into the flag the helper expects.
    chat_api_selected = request_api == "chat"

    # Frontend comes up first; all workers register against its port.
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # The single prefill worker is started before the decode workers.
        with DynamoWorkerProcess(
            request,
            worker_id="worker0",
            frontend_port=frontend.frontend_port,
            disagg_mode="prefill",
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            # Bring up the two decode workers the request migrates between.
            with DynamoWorkerProcess(
                request,
                worker_id="worker1",
                frontend_port=frontend.frontend_port,
                disagg_mode="decode",
            ) as first_decode:
                logger.info(f"Decode Worker 1 PID: {first_decode.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    worker_id="worker2",
                    frontend_port=frontend.frontend_port,
                    disagg_mode="decode",
                ) as second_decode:
                    logger.info(f"Decode Worker 2 PID: {second_decode.get_pid()}")

                    # Hand over to the shared migration scenario.
                    run_migration_test(
                        frontend,
                        first_decode,
                        second_decode,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=chat_api_selected,
                        stream=stream,
                        use_long_prompt=True,
                    )


428
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_decode(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test between decode workers (disaggregated mode).

    Setup: 1 prefill worker + 2 decode workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    # Guard clause: this scenario only makes sense for streamed responses.
    if not stream:
        pytest.skip(
            "Decode test requires streaming to wait for response before stopping worker"
        )

    # Translate the parametrized API name into the flag the helper expects.
    chat_api_selected = request_api == "chat"

    # Frontend comes up first; all workers register against its port.
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend:
        logger.info("Frontend started successfully")

        # The single prefill worker is started before the decode workers.
        with DynamoWorkerProcess(
            request,
            worker_id="worker0",
            frontend_port=frontend.frontend_port,
            disagg_mode="prefill",
        ) as prefill_worker:
            logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")

            # Bring up the two decode workers the request migrates between.
            with DynamoWorkerProcess(
                request,
                worker_id="worker1",
                frontend_port=frontend.frontend_port,
                disagg_mode="decode",
            ) as first_decode:
                logger.info(f"Decode Worker 1 PID: {first_decode.get_pid()}")

                with DynamoWorkerProcess(
                    request,
                    worker_id="worker2",
                    frontend_port=frontend.frontend_port,
                    disagg_mode="decode",
                ) as second_decode:
                    logger.info(f"Decode Worker 2 PID: {second_decode.get_pid()}")

                    # Hand over to the shared migration scenario.
                    run_migration_test(
                        frontend,
                        first_decode,
                        second_decode,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=chat_api_selected,
                        stream=stream,
                        wait_for_new_response_before_stop=True,
                    )