# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test Execution Times (Last Run: 2026-01-13):
- test_request_migration_sglang_aggregated: ~75s
- test_request_migration_sglang_prefill: N/A (test is currently skipped)
- test_request_migration_sglang_kv_transfer: N/A (test is currently skipped)
- test_request_migration_sglang_decode: ~75s
"""

12
13
14
15
16
17
18
import logging
import os
import shutil

import pytest

from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
19
from tests.utils.managed_process import ManagedProcess
20
from tests.utils.payloads import check_models_api
21
from tests.utils.port_utils import allocate_port, deallocate_port
22

23
24
# Customized utils for migration tests
from .utils import DynamoFrontendProcess, run_migration_test
25
26
27
28

logger = logging.getLogger(__name__)

# Marks applied to every test in this module. The parametrize marks stack, so
# each test runs over the cross-product of all axes declared below.
pytestmark = [
    pytest.mark.fault_tolerance,
    pytest.mark.sglang,
    pytest.mark.gpu_1,
    pytest.mark.e2e,
    pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME),
    # Migration budget: > 0 verifies the migration succeeds, 0 verifies the
    # request fails.
    pytest.mark.parametrize(
        "migration_limit",
        [
            pytest.param(3, id="migration_enabled"),
            pytest.param(0, id="migration_disabled"),
        ],
    ),
    # Max sequence length for migration state tracking.
    pytest.mark.parametrize(
        "migration_max_seq_len",
        [
            pytest.param(None, id="max_seq_len_disabled"),
            pytest.param(1_000_000, id="max_seq_len_not_exceeded"),
            pytest.param(1, id="max_seq_len_exceeded"),
        ],
    ),
    # How the worker is stopped: abrupt kill vs. graceful shutdown.
    pytest.mark.parametrize(
        "immediate_kill",
        [
            pytest.param(True, id="worker_failure"),
            pytest.param(False, id="graceful_shutdown"),
        ],
    ),
    # Which API the request goes through; "completion" is skipped for now.
    pytest.mark.parametrize(
        "request_api",
        [
            pytest.param("chat"),
            pytest.param(
                "completion",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    # Streaming vs. unary responses; unary is skipped for now.
    pytest.mark.parametrize(
        "stream",
        [
            pytest.param(True, id="stream"),
            pytest.param(
                False,
                id="unary",
                marks=pytest.mark.skip(reason="Behavior unverified yet"),
            ),
        ],
    ),
    # indirect=True hands the value to the `request_plane` fixture rather than
    # passing it to the test function directly.
    pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True),
]


class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with SGLang backend

    Supports both aggregated mode (single worker) and disaggregated mode
    (separate prefill and decode workers).

    Each instance allocates its own system port (exposed as ``system_port``)
    and releases it again when the context manager exits.

    Args:
        request: pytest request fixture
        worker_id: Unique identifier for the worker (e.g., "worker1", "worker2").
            In disaggregated mode its last character is appended to "1234" to
            form the bootstrap port, so it is expected to end in a digit.
        frontend_port: Port where the frontend is running
        disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated
    """

    def __init__(
        self,
        request,
        worker_id: str,
        frontend_port: int,
        disagg_mode: str | None = None,
    ) -> None:
        self.worker_id = worker_id
        self.system_port = allocate_port(9100)
        self.disagg_mode = disagg_mode

        command = [
            "python3",
            "-m",
            "dynamo.sglang",
            "--model-path",
            FAULT_TOLERANCE_MODEL_NAME,
            "--served-model-name",
            FAULT_TOLERANCE_MODEL_NAME,
            "--trust-remote-code",
            "--page-size",
            "16",
            "--tp",
            "1",
            "--mem-fraction-static",
            "0.3",
            "--context-length",
            "8192",
        ]

        if disagg_mode is None:
            # Aggregated
            command.append("--skip-tokenizer-init")
        else:
            # Disaggregated
            command.extend(
                [
                    "--disaggregation-mode",
                    disagg_mode,
                    "--disaggregation-bootstrap-port",
                    # One bootstrap port per worker, keyed on the id's last character
                    f"1234{worker_id[-1]}",
                    "--host",
                    "0.0.0.0",
                    "--disaggregation-transfer-backend",
                    "nixl",
                ]
            )
            if disagg_mode == "prefill":
                command.extend(["--port", "40000"])

        # Set environment variables
        env = os.environ.copy()
        env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane")

        env["DYN_LOG"] = "debug"
        # Disable canary health check - these tests expect full control over requests
        # sent to the workers where canary health check intermittently sends dummy
        # requests to workers interfering with the test process which may cause
        # intermittent failures
        env["DYN_HEALTH_CHECK_ENABLED"] = "false"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = str(self.system_port)
        env["DYN_HTTP_PORT"] = str(frontend_port)

        # Disable backend shutdown grace period for all migration tests
        env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"

        # Configure health check based on worker type: every worker is probed on
        # its own system port; aggregated and decode workers are additionally
        # checked through the frontend's model listing.
        health_check_urls = [
            (f"http://localhost:{self.system_port}/health", self.is_ready)
        ]
        if disagg_mode is None or disagg_mode == "decode":
            health_check_urls.append(
                (f"http://localhost:{frontend_port}/v1/models", check_models_api)
            )

        # TODO: Have the managed process take a command name explicitly to distinguish
        #       between processes started with the same command.
        log_dir = f"{request.node.name}_{worker_id}"

        # Clean up any existing log directory from previous runs
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass

        super().__init__(
            command=command,
            env=env,
            health_check_urls=health_check_urls,
            timeout=300,
            display_output=True,
            terminate_all_matching_process_names=False,
            stragglers=["SGLANG:EngineCore"],
            straggler_commands=["-m dynamo.sglang"],
            log_dir=log_dir,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Run the parent teardown, then release the allocated system port.

        The port is released in a ``finally`` so that it is returned even if the
        parent teardown raises, and only after that teardown has run, so the
        port is not handed out again while this worker may still be using it
        (presumably the parent ``__exit__`` is what stops the process - confirm
        against ManagedProcess).

        Returns:
            Whatever the parent ``__exit__`` returns.
        """
        try:
            return super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            try:
                # system_port is always set in __init__
                deallocate_port(self.system_port)
            except Exception as e:
                # Best effort: a failed release must not mask the test outcome
                logger.warning(f"Failed to release SGLang worker port: {e}")

    def is_ready(self, response) -> bool:
        """Check the health of the worker process

        Args:
            response: Health-check response object exposing ``.json()``.

        Returns:
            True only when the decoded body is a JSON object whose "status" is
            "ready"; False for any other status, a non-object body, or a body
            that is not valid JSON.
        """
        try:
            data = response.json()
        except ValueError:
            logger.warning(f"{self.worker_id} health response is not valid JSON")
            return False

        # A non-object body (list, string, null, ...) has no status to read
        status = data.get("status") if isinstance(data, dict) else None
        if status == "ready":
            logger.info(f"{self.worker_id} status is ready")
            return True

        logger.warning(f"{self.worker_id} status is not ready: {status}")
        return False


214
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.post_merge
def test_request_migration_sglang_aggregated(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test against two aggregated workers.

    Setup: 1 frontend + 2 aggregated workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    # TODO(OPS-4446): Flaky on NATS transport - first-token delay routinely
    # exceeds the 6s threshold in utils.validate_response. Other parameter
    # combinations (including the TCP variant) are stable.
    # The request_plane fixture is looked up last so it is only resolved when
    # every other condition already matches.
    is_known_flaky_combo = (
        migration_limit == 3
        and migration_max_seq_len is None
        and immediate_kill is True
        and request_api == "chat"
        and stream is True
        and request.getfixturevalue("request_plane") == "nats"
    )
    if is_known_flaky_combo:
        pytest.skip(
            "Flaky on NATS transport: first-token delay > 6s threshold. OPS-4446"
        )

    use_chat_api = request_api == "chat"

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend_proc:
        logger.info("Frontend started successfully")
        http_port = frontend_proc.frontend_port

        # Step 2: Start 2 workers
        with DynamoWorkerProcess(request, "worker1", http_port) as first_worker:
            logger.info(f"Worker 1 PID: {first_worker.get_pid()}")

            with DynamoWorkerProcess(request, "worker2", http_port) as second_worker:
                logger.info(f"Worker 2 PID: {second_worker.get_pid()}")

                # Step 3: Run migration test
                run_migration_test(
                    frontend_proc,
                    first_worker,
                    second_worker,
                    receiving_pattern="New Request ID: ",
                    migration_limit=migration_limit,
                    migration_max_seq_len=migration_max_seq_len,
                    immediate_kill=immediate_kill,
                    use_chat_completion=use_chat_api,
                    stream=stream,
                )

285

286
287
288
@pytest.mark.skip(reason="Cannot reliably migrate at Prefill that finish < 1 ms")
@pytest.mark.xfail(strict=False, reason="Prefill migration not yet supported")
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_prefill(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test across prefill workers (disaggregated mode).

    Setup: 1 decode worker + 2 prefill workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    use_chat_api = request_api == "chat"

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend_proc:
        logger.info("Frontend started successfully")
        http_port = frontend_proc.frontend_port

        # Step 2: Start decode worker first (required for prefill workers to connect)
        with DynamoWorkerProcess(
            request, "worker0", http_port, disagg_mode="decode"
        ) as decode_proc:
            logger.info(f"Decode Worker PID: {decode_proc.get_pid()}")

            # Step 3: Start 2 prefill workers
            with DynamoWorkerProcess(
                request, "worker1", http_port, disagg_mode="prefill"
            ) as first_prefill:
                logger.info(f"Prefill Worker 1 PID: {first_prefill.get_pid()}")

                with DynamoWorkerProcess(
                    request, "worker2", http_port, disagg_mode="prefill"
                ) as second_prefill:
                    logger.info(f"Prefill Worker 2 PID: {second_prefill.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend_proc,
                        first_prefill,
                        second_prefill,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=use_chat_api,
                        stream=stream,
                        use_long_prompt=True,
                    )
360

361

362
363
@pytest.mark.skip(reason="KV cache transfer may fail")
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_kv_transfer(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test during KV transfer (disaggregated mode).

    Setup: 1 prefill worker + 2 decode workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    use_chat_api = request_api == "chat"

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend_proc:
        logger.info("Frontend started successfully")
        http_port = frontend_proc.frontend_port

        # Step 2: Start prefill worker first
        with DynamoWorkerProcess(
            request, "worker0", http_port, disagg_mode="prefill"
        ) as prefill_proc:
            logger.info(f"Prefill Worker PID: {prefill_proc.get_pid()}")

            # Step 3: Start 2 decode workers
            with DynamoWorkerProcess(
                request, "worker1", http_port, disagg_mode="decode"
            ) as first_decode:
                logger.info(f"Decode Worker 1 PID: {first_decode.get_pid()}")

                with DynamoWorkerProcess(
                    request, "worker2", http_port, disagg_mode="decode"
                ) as second_decode:
                    logger.info(f"Decode Worker 2 PID: {second_decode.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend_proc,
                        first_decode,
                        second_decode,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=use_chat_api,
                        stream=stream,
                        use_long_prompt=True,
                    )


437
@pytest.mark.timeout(230)  # 3x average
@pytest.mark.nightly
def test_request_migration_sglang_decode(
    request,
    runtime_services_dynamic_ports,
    set_ucx_tls_no_mm,
    predownload_models,
    migration_limit,
    migration_max_seq_len,
    immediate_kill,
    request_api,
    stream,
):
    """
    End-to-end request migration test across decode workers (disaggregated mode).

    Setup: 1 prefill worker + 2 decode workers

    Parameters:
        immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
        migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
        migration_max_seq_len: Max sequence length for migration state tracking
        request_api: "chat" for chat completion API, "completion" for completion API
        stream: True for streaming, False for non-streaming
    """
    # Guard: this scenario only makes sense for streaming requests.
    if not stream:
        pytest.skip(
            "Decode test requires streaming to wait for response before stopping worker"
        )

    use_chat_api = request_api == "chat"

    # Step 1: Start the frontend
    with DynamoFrontendProcess(
        request,
        migration_limit=migration_limit,
        migration_max_seq_len=migration_max_seq_len,
    ) as frontend_proc:
        logger.info("Frontend started successfully")
        http_port = frontend_proc.frontend_port

        # Step 2: Start prefill worker first
        with DynamoWorkerProcess(
            request, "worker0", http_port, disagg_mode="prefill"
        ) as prefill_proc:
            logger.info(f"Prefill Worker PID: {prefill_proc.get_pid()}")

            # Step 3: Start 2 decode workers
            with DynamoWorkerProcess(
                request, "worker1", http_port, disagg_mode="decode"
            ) as first_decode:
                logger.info(f"Decode Worker 1 PID: {first_decode.get_pid()}")

                with DynamoWorkerProcess(
                    request, "worker2", http_port, disagg_mode="decode"
                ) as second_decode:
                    logger.info(f"Decode Worker 2 PID: {second_decode.get_pid()}")

                    # Step 4: Run migration test
                    run_migration_test(
                        frontend_proc,
                        first_decode,
                        second_decode,
                        receiving_pattern="New Request ID: ",
                        migration_limit=migration_limit,
                        migration_max_seq_len=migration_max_seq_len,
                        immediate_kill=immediate_kill,
                        use_chat_completion=use_chat_api,
                        stream=stream,
                        wait_for_new_response_before_stop=True,
                    )